From 0c083a556c44732a9b6cc7f80b65765e657509ec Mon Sep 17 00:00:00 2001 From: OutisLi Date: Wed, 3 Jun 2026 16:09:29 +0800 Subject: [PATCH 01/18] feat: nv nlist --- backend/find_pytorch.py | 3 + deepmd/pt/infer/deep_eval.py | 119 ++++--- deepmd/pt/utils/nv_nlist.py | 334 ++++++++++++++++++++ source/tests/pt/model/test_nlist_backend.py | 121 ++++--- source/tests/pt/model/test_nv_nlist.py | 225 +++++++++++++ 5 files changed, 721 insertions(+), 81 deletions(-) create mode 100644 deepmd/pt/utils/nv_nlist.py create mode 100644 source/tests/pt/model/test_nv_nlist.py diff --git a/backend/find_pytorch.py b/backend/find_pytorch.py index bccb2908f1..48361d86af 100644 --- a/backend/find_pytorch.py +++ b/backend/find_pytorch.py @@ -142,6 +142,9 @@ def get_pt_requirement(pt_version: str = "") -> dict: # under the torch extra rather than the core deps (conda-forge has # vesin but not vesin-torch). "vesin[torch]", + # GPU O(N) cell-list neighbor list for large systems; the package + # requires Python >= 3.11 while deepmd-kit still supports 3.10. + "nvalchemi-toolkit-ops>=0.3.1; python_version >= '3.11'", *mpi_requirement, *cibw_requirement, ], diff --git a/deepmd/pt/infer/deep_eval.py b/deepmd/pt/infer/deep_eval.py index e398ac9210..640b5fa918 100644 --- a/deepmd/pt/infer/deep_eval.py +++ b/deepmd/pt/infer/deep_eval.py @@ -66,6 +66,10 @@ GLOBAL_PT_FLOAT_PRECISION, RESERVED_PRECISION_DICT, ) +from deepmd.pt.utils.nv_nlist import ( + NvNeighborList, + is_nv_available, +) from deepmd.pt.utils.utils import ( to_numpy_array, to_torch_tensor, @@ -255,63 +259,88 @@ def __init__( def _setup_nlist_backend(self, nlist_backend: str) -> None: """Resolve the neighbor-list construction strategy from a user choice. - ``"native"`` uses the dense all-pairs builder; ``"vesin"`` forces the - O(N) ``vesin.torch`` cell list (raising if it is unavailable or the - model/inputs are unsupported); ``"auto"`` uses vesin when applicable and - silently falls back to the native builder otherwise. Results are - unchanged either way -- only the neighbor-search cost differs. + ``"native"`` uses the dense all-pairs builder; ``"vesin"`` / ``"nv"`` + force the O(N) ``vesin.torch`` / ``nvalchemiops`` cell list (raising if + unavailable or the model/inputs are unsupported); ``"auto"`` picks the + first available O(N) builder (vesin, then nv) and otherwise falls back to + the native builder. Results are unchanged either way -- only the + neighbor-search cost differs. """ - if nlist_backend not in ("auto", "vesin", "native"): + inner = self.dp.model["Default"] + self_built = getattr(inner, "use_self_built_nlist", None) + if callable(self_built) and self_built(): + # The model builds its own neighbor list and runs the native path; + # an external strategy would bypass it, so always use native. + log.info( + "Ignoring nlist_backend=%r: %s uses its own built-in neighbor list.", + nlist_backend, + type(inner).__name__, + ) + self._nlist_builder = None + return + if nlist_backend not in ("auto", "vesin", "nv", "native"): raise ValueError( f"Unknown nlist_backend '{nlist_backend}'; " - "expected 'auto', 'vesin', or 'native'." + "expected 'auto', 'vesin', 'nv', or 'native'." ) - # reason vesin cannot be used (None means it can) + + # reason an external strategy cannot be used (None means it can) unsupported = None if self._has_spin: unsupported = "spin models" elif self._has_hessian: unsupported = "hessian models" elif self.modifier is not None: - # the vesin path runs forward_common_lower directly, bypassing + # the strategy path runs forward_common_lower directly, bypassing # ModelWrapper.forward (which applies the data modifier); fall back # to the native path so the modifier is still applied. unsupported = "models with a data modifier" - elif "energy" not in self.dp.model["Default"].model_output_type(): - # _eval_lower_vesin reconstructs the backend output from the + elif "energy" not in inner.model_output_type(): + # _eval_lower_strategy reconstructs the backend output from the # forward_common_lower / communicate keys via _OUTDEF_DP2BACKEND, # which matches the model's own translation only for the energy # model (e.g. the polar fitting key is "polarizability" but the - # backend output is "polar"). Restrict vesin to energy models -- - # the large-system inference target -- and fall back to native - # for the other fitting types. + # backend output is "polar"). Restrict strategies to energy models + # and fall back to native for the other fitting types. unsupported = "non-energy models" ase_provided = self.neighbor_list is not None - if nlist_backend == "native": - self._use_vesin = False - elif nlist_backend == "vesin": - if not is_vesin_torch_available(): - raise ImportError( - "nlist_backend='vesin' was requested but 'vesin.torch' is " - "not installed. Install it (`pip install vesin[torch]`) or " - "use nlist_backend='native' (or 'auto')." - ) + + builder = None + if nlist_backend in ("vesin", "nv"): if unsupported is not None: raise ValueError( - f"nlist_backend='vesin' is not supported for {unsupported}; " - "use nlist_backend='native' (or 'auto')." + f"nlist_backend='{nlist_backend}' is not supported for " + f"{unsupported}; use nlist_backend='native' (or 'auto')." ) if ase_provided: raise ValueError( - "nlist_backend='vesin' conflicts with an explicitly " - "supplied ASE neighbor_list; pass only one." + f"nlist_backend='{nlist_backend}' conflicts with an " + "explicitly supplied ASE neighbor_list; pass only one." ) - self._use_vesin = True - else: # auto: use vesin when possible, otherwise fall back silently - self._use_vesin = ( - is_vesin_torch_available() and unsupported is None and not ase_provided - ) - self._nlist_builder = VesinNeighborList() if self._use_vesin else None + if nlist_backend == "vesin": + if not is_vesin_torch_available(): + raise ImportError( + "nlist_backend='vesin' was requested but 'vesin.torch' " + "is not installed. Install it (`pip install " + "vesin[torch]`) or use nlist_backend='native' (or 'auto')." + ) + builder = VesinNeighborList() + elif not is_nv_available(): + raise ImportError( + "nlist_backend='nv' was requested but 'nvalchemi-toolkit-ops'" + " is not installed. Install it (`pip install " + "nvalchemi-toolkit-ops`) or use nlist_backend='native' " + "(or 'auto')." + ) + else: + builder = NvNeighborList() + elif nlist_backend == "auto" and unsupported is None and not ase_provided: + # Pick the first available O(N) builder; nv is GPU-only. + if is_vesin_torch_available(): + builder = VesinNeighborList() + elif is_nv_available() and torch.cuda.is_available(): + builder = NvNeighborList() + self._nlist_builder = builder def get_rcut(self) -> float: """Get the cutoff radius of this model.""" @@ -659,8 +688,8 @@ def _eval_model( do_atomic_virial = any( x.category == OutputVariableCategory.DERV_C for x in request_defs ) - if self._use_vesin: - batch_output = self._eval_lower_vesin( + if self._nlist_builder is not None: + batch_output = self._eval_lower_strategy( coord_input, type_input, box_input, @@ -696,7 +725,7 @@ def _eval_model( ) # this is kinda hacky return tuple(results) - def _eval_lower_vesin( + def _eval_lower_strategy( self, coord: torch.Tensor, atype: torch.Tensor, @@ -706,15 +735,15 @@ def _eval_lower_vesin( charge_spin: torch.Tensor | None, do_atomic_virial: bool, ) -> dict[str, torch.Tensor]: - """Evaluate via the O(N) vesin-built ``(i,j,S)`` extended neighbor list. - - Builds the extended representation with the vesin cell list, runs the - model's ``forward_common_lower``, and maps the extended outputs back to - local atoms with ``communicate_extended_output``. Returns a dict keyed - by backend names, matching the normal ``model()`` output so the caller's - extraction is unchanged. ``forward_common_atomic`` sets - ``requires_grad`` on the extended coordinates internally, exactly as on - the native path, so forces/virials are produced identically. + """Evaluate via the selected O(N) ``NeighborList`` strategy. + + Builds the extended representation with ``self._nlist_builder`` (vesin or + nv), runs the model's ``forward_common_lower``, and maps the extended + outputs back to local atoms with ``communicate_extended_output``. + Returns a dict keyed by backend names, matching the normal ``model()`` + output so the caller's extraction is unchanged. ``requires_grad`` is set + on the extended coordinates internally, exactly as on the native path, so + forces/virials are produced identically. """ inner = self.dp.model["Default"] ext_coord, ext_atype, nlist, mapping = self._nlist_builder.build( diff --git a/deepmd/pt/utils/nv_nlist.py b/deepmd/pt/utils/nv_nlist.py new file mode 100644 index 0000000000..460fcadbb1 --- /dev/null +++ b/deepmd/pt/utils/nv_nlist.py @@ -0,0 +1,334 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Toolkit-Ops (``nvalchemiops``) neighbor-list strategy. + +A :class:`~deepmd.dpmodel.utils.neighbor_list.NeighborList` implementation that +builds the extended representation ``(extended_coord, extended_atype, nlist, +mapping)`` using the device-resident O(N) cell list in ``nvalchemiops``, intended +for large periodic systems. + +Toolkit-Ops returns a dense ``[total_atoms, max_neighbors]`` neighbor matrix over +the flattened batch. The matrix is converted to the DeePMD extended-atom contract +by materializing each unique ghost ``(frame, src_local, shift)`` once; the +candidate list is then distance-sorted and truncated to ``sum(sel)`` so the +returned neighbor count is fixed. The search is non-differentiable and runs on +detached coordinates, while ``extended_coord`` is rebuilt from the input +coordinates so force and virial gradients propagate unchanged. +""" + +from __future__ import ( + annotations, +) + +from typing import ( + Any, +) + +import torch + +from deepmd.dpmodel.utils.neighbor_list import ( + NeighborList, +) +from deepmd.pt.utils.region import ( + normalize_coord, +) + +NV_CELL_LIST_THRESHOLD = 1024 + + +def is_nv_available() -> bool: + """Whether the ``nvalchemiops`` Toolkit-Ops neighbor list is importable.""" + try: + import nvalchemiops.torch.neighbors # noqa: F401 + except ImportError: + return False + return True + + +def choose_nv_nlist_method(nloc: int) -> str: + """Choose the Toolkit-Ops neighbor method for a homogeneous batch. + + Parameters + ---------- + nloc + Number of local atoms per frame. + + Returns + ------- + str + Toolkit-Ops method name. + """ + if nloc >= NV_CELL_LIST_THRESHOLD: + return "batch_cell_list" + return "batch_naive" + + +class NvNeighborList(NeighborList): + """O(N) neighbor-list strategy using the ``nvalchemiops`` cell list. + + Implements the :class:`~deepmd.dpmodel.utils.neighbor_list.NeighborList` + interface on torch tensors; the search runs on the device of the input + coordinates. A periodic ``box`` is required -- the cell list needs a cell to + wrap periodic images. + """ + + def build( + self, + coord: Any, + atype: Any, + box: Any, + rcut: float, + sel: list[int], + ) -> tuple[Any, Any, Any, Any]: + """Build the extended system and neighbor list. + + See :meth:`deepmd.dpmodel.utils.neighbor_list.NeighborList.build`. The + returned ``nlist`` is distance-sorted and truncated to ``sum(sel)``. A + periodic ``box`` is required, as the cell list operates on a periodic cell. + """ + from nvalchemiops.torch.neighbors import ( + neighbor_list, + ) + + if box is None: + raise ValueError("NvNeighborList requires a periodic box; got box=None.") + + nf, nloc = atype.shape[:2] + device = coord.device + target_neighbors = int(sum(sel)) + search_capacity = target_neighbors + total_atoms = nf * nloc + cell = box.reshape(nf, 3, 3).to(device=device, dtype=coord.dtype) + coord = normalize_coord(coord.reshape(nf, nloc, 3), cell) + positions_for_nlist = coord.reshape(total_atoms, 3).detach() + pbc = torch.ones((nf, 3), dtype=torch.bool, device=device) + batch_idx = torch.arange( + nf, dtype=torch.int32, device=device + ).repeat_interleave(nloc) + batch_ptr = torch.arange(nf + 1, dtype=torch.int32, device=device) * nloc + method = choose_nv_nlist_method(nloc) + + # Grow the search capacity until all neighbors fit so the distance-sort + # below selects the true nearest ``sum(sel)``. + while True: + neighbor_matrix, num_neighbors, shifts = neighbor_list( + positions_for_nlist, + float(rcut), + cell=cell, + pbc=pbc, + batch_idx=batch_idx, + batch_ptr=batch_ptr, + method=method, + max_neighbors=int(search_capacity), + return_neighbor_list=False, + wrap_positions=False, + ) + max_found = ( + int(num_neighbors.max().item()) if num_neighbors.numel() > 0 else 0 + ) + if max_found <= search_capacity: + break + search_capacity = max(max_found, _grow_search_capacity(search_capacity)) + + extended_coord, extended_atype, mapping, nlist = _matrix_to_extended_inputs( + coord=coord, + atype=atype, + cell=cell, + nloc=nloc, + neighbor_matrix=neighbor_matrix, + num_neighbors=num_neighbors, + shifts=shifts, + ) + nlist = _truncate_to_sel_compiled( + extended_coord, nlist, target_neighbors, float(rcut) + ) + return extended_coord, extended_atype, nlist, mapping + + +def _grow_search_capacity(capacity: int) -> int: + """Increase Toolkit-Ops capacity by 1.25x, rounded up.""" + return (capacity * 5 + 3) // 4 + + +@torch.no_grad() +def _truncate_to_sel( + extended_coord: torch.Tensor, + nlist: torch.Tensor, + nsel: int, + rcut: float, +) -> torch.Tensor: + """Distance-sort the candidate neighbor list and keep the nearest ``nsel`` + within ``rcut``, padding with ``-1`` when fewer neighbors exist. + + The Toolkit-Ops search capacity may exceed ``sum(sel)`` on dense systems; this + fixes the returned neighbor count at ``nsel``. + + The output is the integer ``nlist``; ``extended_coord`` is only read to rank + candidates and is returned unchanged by the caller. The routine is therefore + non-differentiable and runs under ``no_grad`` so it never participates in the + autograd graph (forward, backward, or the second-order pass used to train + forces), which also avoids retaining the distance temporaries for backward. + """ + nf, nloc, nnei = nlist.shape + if nnei < nsel: + pad = torch.full( + (nf, nloc, nsel - nnei), -1, dtype=nlist.dtype, device=nlist.device + ) + return torch.cat([nlist, pad], dim=-1) + if nnei == nsel: + return nlist + real_neighbor = nlist >= 0 + safe_nlist = torch.where(real_neighbor, nlist, torch.zeros_like(nlist)) + coord0 = extended_coord[:, :nloc, :] + index = safe_nlist.view(nf, nloc * nnei, 1).expand(-1, -1, 3) + coord1 = torch.gather(extended_coord, 1, index).view(nf, nloc, nnei, 3) + rr = torch.linalg.norm(coord1 - coord0[:, :, None, :], dim=-1) + rr = torch.where(real_neighbor, rr, float("inf")) + rr, order = torch.sort(rr, dim=-1) + sorted_nlist = torch.gather(safe_nlist, 2, order) + sorted_nlist = torch.where(rr > rcut, -1, sorted_nlist) + # ``.contiguous()`` is required: the bare ``[..., :nsel]`` slice keeps the + # wider candidate stride, but the compiled lower interface freezes the nlist + # sel axis and asserts a contiguous layout (``assert_size_stride``). + return sorted_nlist[..., :nsel].contiguous() + + +# Lower the gather/distance-sort/mask pipeline of `_truncate_to_sel` into a single +# Inductor graph. ``dynamic=True`` keeps the per-system ``(nf, nloc, nnei)`` shapes +# on one compiled graph instead of recompiling per system size, and fusing the +# pipeline avoids materializing the full ``(nf, nloc, nnei, 3)`` distance +# temporaries, which lowers both this step's peak memory and its latency relative +# to eager. Compilation is lazy: it happens on first call, not at import. +_truncate_to_sel_compiled = torch.compile(_truncate_to_sel, dynamic=True) + + +def _matrix_to_extended_inputs( + *, + coord: torch.Tensor, + atype: torch.Tensor, + cell: torch.Tensor, + nloc: int, + neighbor_matrix: torch.Tensor, + num_neighbors: torch.Tensor, + shifts: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Convert Toolkit-Ops matrix output to compact extended inputs. + + Toolkit-Ops returns neighbors as a dense matrix over flattened atoms: + ``neighbor_matrix[dst_global, slot] = src_global`` and + ``shifts[dst_global, slot] = (sx, sy, sz)``. Here ``dst_global`` and + ``src_global`` are indices in the concatenated ``nf * nloc`` input. + + DeePMD lower paths use a different contract: ``nlist`` stores indices into + ``extended_coord``. Local atoms occupy ``[0, nloc)`` in each frame, while + shifted PBC images must be appended as ghost atoms. This conversion builds + the minimal ghost set by materializing each unique + ``(frame, src_local, shift)`` once, then redirects all shifted nlist entries + to the corresponding compact ghost index. + """ + nf = coord.shape[0] + total_atoms, max_neighbors = neighbor_matrix.shape + device = coord.device + dtype = coord.dtype + local_mapping = torch.arange(nloc, dtype=torch.long, device=device) + local_mapping = local_mapping.unsqueeze(0).expand(nf, -1) + nlist = torch.full((nf, nloc, max_neighbors), -1, dtype=torch.long, device=device) + + # === Step 1. Flatten valid Toolkit-Ops matrix slots === + # `edge_idx` indexes the flattened matrix layout `(total_atoms, max_neighbors)`. + # This avoids constructing a full repeated destination tensor. + slot = torch.arange(max_neighbors, dtype=torch.long, device=device).expand( + total_atoms, max_neighbors + ) + valid = (slot < num_neighbors.unsqueeze(1)).reshape(-1) + edge_idx = torch.nonzero(valid, as_tuple=False).flatten() + if edge_idx.numel() == 0: + return coord, atype, local_mapping, nlist + + # Decode flattened edge slots: + # dst : flattened center atom, in [0, nf * nloc) + # src : flattened neighbor atom returned by Toolkit-Ops + # frame_idx : batch frame/system containing both dst and src + # center_idx : local center atom index inside the frame + # src_local : local neighbor atom index before applying the PBC shift + dst = edge_idx // max_neighbors + src = neighbor_matrix.reshape(-1).index_select(0, edge_idx).to(dtype=torch.long) + shift = shifts.reshape(-1, 3).index_select(0, edge_idx).to(dtype=torch.long) + src_local = src % nloc + frame_idx = dst // nloc + center_idx = dst % nloc + slot_idx = edge_idx % max_neighbors + zero_shift = torch.all(shift == 0, dim=1) + + # === Step 2. Direct neighbors keep their local extended indices === + # Zero-shift neighbors already live in the leading local block of + # `extended_coord`, so their DeePMD nlist value is simply `src_local`. + direct_edge_idx = torch.nonzero(zero_shift, as_tuple=False).flatten() + nlist[ + frame_idx.index_select(0, direct_edge_idx), + center_idx.index_select(0, direct_edge_idx), + slot_idx.index_select(0, direct_edge_idx), + ] = src_local.index_select(0, direct_edge_idx) + + shifted_edge_idx = torch.nonzero(~zero_shift, as_tuple=False).flatten() + if shifted_edge_idx.numel() == 0: + return coord, atype, local_mapping, nlist + + # === Step 3. Materialize each unique shifted atom once per frame === + # A shifted source may appear in many center atoms' neighbor slots. Dedup by + # `(frame, src_local, shift)` so all such slots share one compact ghost atom. + ghost_keys = torch.cat( + [ + frame_idx.index_select(0, shifted_edge_idx).unsqueeze(1), + src_local.index_select(0, shifted_edge_idx).unsqueeze(1), + shift.index_select(0, shifted_edge_idx), + ], + dim=1, + ) + unique_keys, inverse = torch.unique(ghost_keys, dim=0, return_inverse=True) + ghost_frame = unique_keys[:, 0].to(dtype=torch.long) + ghost_src = unique_keys[:, 1].to(dtype=torch.long) + ghost_shift = unique_keys[:, 2:].to(dtype=dtype) + + # Assign per-frame compact ghost indices. `ghost_rank` is the offset within + # a frame's ghost block, so the final extended index is `nloc + ghost_rank`. + # The `.item()` sync is only used to size the padded dense output. + ghost_count = torch.bincount(ghost_frame, minlength=nf) + max_extra = int(ghost_count.max().item()) + order = torch.argsort(ghost_frame) + sorted_frame = ghost_frame.index_select(0, order) + frame_start = torch.cumsum(ghost_count, dim=0) - ghost_count + sorted_rank = torch.arange( + unique_keys.shape[0], dtype=torch.long, device=device + ) - frame_start.index_select(0, sorted_frame) + ghost_rank = torch.empty_like(sorted_rank) + ghost_rank[order] = sorted_rank + ghost_index = nloc + ghost_rank + + extended_coord = torch.zeros((nf, nloc + max_extra, 3), dtype=dtype, device=device) + extended_atype = torch.full( + (nf, nloc + max_extra), -1, dtype=atype.dtype, device=device + ) + mapping = torch.zeros((nf, nloc + max_extra), dtype=torch.long, device=device) + extended_coord[:, :nloc] = coord + extended_atype[:, :nloc] = atype + mapping[:, :nloc] = local_mapping + + # Convert integer cell shifts to Cartesian ghost coordinates and record the + # extended-to-local mapping used later to scatter forces/virials back. + shift_cart = torch.bmm( + ghost_shift.unsqueeze(1), cell.index_select(0, ghost_frame) + ).squeeze(1) + extended_coord[ghost_frame, ghost_index] = ( + coord[ghost_frame, ghost_src] + shift_cart + ) + extended_atype[ghost_frame, ghost_index] = atype[ghost_frame, ghost_src] + mapping[ghost_frame, ghost_index] = ghost_src + + # Redirect shifted neighbor slots to their compact ghost indices. `inverse` + # maps each shifted edge's key back to its row in `unique_keys`. + shifted_nlist_values = ghost_index.index_select(0, inverse) + shifted_frames = frame_idx.index_select(0, shifted_edge_idx) + shifted_centers = center_idx.index_select(0, shifted_edge_idx) + shifted_slots = slot_idx.index_select(0, shifted_edge_idx) + nlist[shifted_frames, shifted_centers, shifted_slots] = shifted_nlist_values + return extended_coord, extended_atype, mapping, nlist diff --git a/source/tests/pt/model/test_nlist_backend.py b/source/tests/pt/model/test_nlist_backend.py index ffc4b3e64f..a09127aa31 100644 --- a/source/tests/pt/model/test_nlist_backend.py +++ b/source/tests/pt/model/test_nlist_backend.py @@ -1,10 +1,12 @@ # SPDX-License-Identifier: LGPL-3.0-or-later -"""``nlist_backend`` dispatch + vesin/native equivalence for the pt backend. +"""``nlist_backend`` dispatch + O(N) strategy equivalence for the pt backend. The pt model is reconstructed eagerly in ``DeepEval`` and evaluated via -``forward_common_lower`` when the O(N) vesin neighbor list is selected (the -exported TorchScript graph is untouched). native and vesin must give identical -results, and the ``nlist_backend`` choice must dispatch / validate correctly. +``forward_common_lower`` when an O(N) neighbor-list strategy is selected. Each +strategy (``vesin``, ``nv``) must give results identical to the native dense +builder, and the ``nlist_backend`` choice must dispatch / validate correctly. +Strategy tests are skipped when the backend (or, for ``nv``, a CUDA device) is +unavailable. """ import copy @@ -22,13 +24,27 @@ from deepmd.pt.train.wrapper import ( ModelWrapper, ) +from deepmd.pt.utils.nv_nlist import ( + NvNeighborList, + is_nv_available, +) from deepmd.pt_expt.utils.vesin_neighbor_list import ( + VesinNeighborList, is_vesin_torch_available, ) -pytestmark = pytest.mark.skipif( - not is_vesin_torch_available(), reason="vesin.torch is not installed" -) +# Each O(N) strategy: (backend name, builder class, availability skip mark). +_BACKEND_MARKS = { + "vesin": pytest.mark.skipif( + not is_vesin_torch_available(), reason="vesin.torch is not installed" + ), + "nv": pytest.mark.skipif( + not (is_nv_available() and torch.cuda.is_available()), + reason="nvalchemiops CUDA neighbor list unavailable", + ), +} +_BUILDER_CLS = {"vesin": VesinNeighborList, "nv": NvNeighborList} +STRATEGIES = [pytest.param(name, marks=mark) for name, mark in _BACKEND_MARKS.items()] TYPE_MAP = ["O", "H", "B"] @@ -79,6 +95,7 @@ def _save_pt(md_dict: dict, path: str) -> None: def _system(): + """Single periodic frame (every strategy supports a periodic box).""" rng = np.random.default_rng(20240604) coords = (rng.random((1, 8, 3)) * 6.0).astype(np.float64) atype = np.array([0, 0, 1, 1, 2, 0, 1, 2], dtype=np.int64) @@ -88,7 +105,7 @@ def _system(): def _multiframe_system(nframes: int = 3): """Frames with different box sizes -> different per-frame ghost counts, - exercising the vesin builder's pad-to-common-nall + stack path. + exercising the builder's pad-to-common-nall + stack path. """ rng = np.random.default_rng(20240604) atype = np.array([0, 0, 1, 1, 2, 0, 1, 2], dtype=np.int64) @@ -111,46 +128,78 @@ def pt_files(tmp_path_factory): return files -def test_default_is_auto(pt_files) -> None: - # vesin is available (module skip guard), non-spin/non-hessian -> auto picks it - assert DeepPot(pt_files["se_e2_a"]).deep_eval._use_vesin is True +def _assert_eval_close(dp_ref, dp_test, coords, cells, atype, msg: str) -> None: + ref = dp_ref.eval(coords, cells, atype, atomic=True) + out = dp_test.eval(coords, cells, atype, atomic=True) + for a, b, label in zip(ref, out, ["e", "f", "v", "ae", "av"], strict=True): + np.testing.assert_allclose(a, b, rtol=1e-9, atol=1e-9, err_msg=f"{msg} {label}") -def test_native_disables_vesin(pt_files) -> None: - dp = DeepPot(pt_files["se_e2_a"], nlist_backend="native") - assert dp.deep_eval._use_vesin is False +# --- dispatch / selection --------------------------------------------------- -def test_invalid_raises(pt_files) -> None: +def test_invalid_backend_raises(pt_files) -> None: with pytest.raises(ValueError): DeepPot(pt_files["se_e2_a"], nlist_backend="bogus") -@pytest.mark.parametrize("name", list(ALL_MODELS)) # descriptor family -@pytest.mark.parametrize("periodic", [False, True]) # non-PBC vs PBC -def test_vesin_matches_native(pt_files, name: str, periodic: bool) -> None: - """Vesin and native give identical energy/force/virial/atomic-virial.""" +def test_native_uses_no_strategy(pt_files) -> None: + dp = DeepPot(pt_files["se_e2_a"], nlist_backend="native") + assert dp.deep_eval._nlist_builder is None + + +@pytest.mark.parametrize("backend", STRATEGIES) +def test_explicit_backend_selects_builder(pt_files, backend: str) -> None: + dp = DeepPot(pt_files["se_e2_a"], nlist_backend=backend) + assert isinstance(dp.deep_eval._nlist_builder, _BUILDER_CLS[backend]) + + +@_BACKEND_MARKS["vesin"] +def test_auto_prefers_vesin(pt_files) -> None: + # auto picks the first available O(N) builder; vesin is preferred. + builder = DeepPot(pt_files["se_e2_a"]).deep_eval._nlist_builder + assert isinstance(builder, VesinNeighborList) + + +def test_self_built_model_forces_native(pt_files, monkeypatch) -> None: + # A model reporting use_self_built_nlist()=True keeps the native path and + # ignores the requested backend (without even validating its name). + deep_eval = DeepPot(pt_files["se_e2_a"], nlist_backend="native").deep_eval + inner = deep_eval.dp.model["Default"] + monkeypatch.setattr(inner, "use_self_built_nlist", lambda: True, raising=False) + for backend in ("auto", "vesin", "nv", "bogus"): + deep_eval._setup_nlist_backend(backend) + assert deep_eval._nlist_builder is None + + +# --- equivalence with the native dense builder ------------------------------ + + +@pytest.mark.parametrize("name", list(ALL_MODELS)) +@pytest.mark.parametrize("backend", STRATEGIES) +def test_strategy_matches_native(pt_files, backend: str, name: str) -> None: + """Each strategy matches native on a periodic single-frame system.""" coords, atype, box = _system() - cells = box if periodic else None dp_native = DeepPot(pt_files[name], nlist_backend="native") - dp_vesin = DeepPot(pt_files[name], nlist_backend="vesin") - ref = dp_native.eval(coords, cells, atype, atomic=True) - out = dp_vesin.eval(coords, cells, atype, atomic=True) - for a, b, label in zip(ref, out, ["e", "f", "v", "ae", "av"], strict=True): - np.testing.assert_allclose( - a, b, rtol=1e-9, atol=1e-9, err_msg=f"{name} {label}" - ) + dp_strat = DeepPot(pt_files[name], nlist_backend=backend) + _assert_eval_close(dp_native, dp_strat, coords, box, atype, f"{name} {backend}") -@pytest.mark.parametrize("name", list(ALL_MODELS)) # descriptor family -def test_vesin_matches_native_multiframe(pt_files, name: str) -> None: - """Multi-frame eval (frames with differing ghost counts) matches native.""" +@pytest.mark.parametrize("name", list(ALL_MODELS)) +@pytest.mark.parametrize("backend", STRATEGIES) +def test_strategy_matches_native_multiframe(pt_files, backend: str, name: str) -> None: + """Each strategy matches native across frames with differing ghost counts.""" coords, atype, box = _multiframe_system() dp_native = DeepPot(pt_files[name], nlist_backend="native") + dp_strat = DeepPot(pt_files[name], nlist_backend=backend) + _assert_eval_close(dp_native, dp_strat, coords, box, atype, f"{name} {backend} mf") + + +@_BACKEND_MARKS["vesin"] +@pytest.mark.parametrize("name", list(ALL_MODELS)) +def test_vesin_matches_native_nonperiodic(pt_files, name: str) -> None: + """Vesin also supports non-periodic systems (nv requires a periodic box).""" + coords, atype, _ = _system() + dp_native = DeepPot(pt_files[name], nlist_backend="native") dp_vesin = DeepPot(pt_files[name], nlist_backend="vesin") - ref = dp_native.eval(coords, box, atype, atomic=True) - out = dp_vesin.eval(coords, box, atype, atomic=True) - for a, b, label in zip(ref, out, ["e", "f", "v", "ae", "av"], strict=True): - np.testing.assert_allclose( - a, b, rtol=1e-9, atol=1e-9, err_msg=f"{name} {label}" - ) + _assert_eval_close(dp_native, dp_vesin, coords, None, atype, f"{name} vesin nopbc") diff --git a/source/tests/pt/model/test_nv_nlist.py b/source/tests/pt/model/test_nv_nlist.py new file mode 100644 index 0000000000..607f0702dd --- /dev/null +++ b/source/tests/pt/model/test_nv_nlist.py @@ -0,0 +1,225 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Unit tests for the ``NvNeighborList`` builder. + +These cover the builder paths the DeepEval end-to-end equivalence test +(``test_nlist_backend.py``) cannot reach with its small ``batch_naive`` systems: +the ``batch_cell_list`` method and the over-capacity distance-trim path, plus the +periodic-box requirement. Built neighbor lists are compared against the native +dense builder at the nlist level (edge topology + geometry). +""" + +import unittest +from unittest import ( + mock, +) + +import torch + +from deepmd.pt.utils import ( + env, + nv_nlist, +) +from deepmd.pt.utils.nlist import ( + extend_input_and_build_neighbor_list, +) +from deepmd.pt.utils.nv_nlist import ( + NvNeighborList, +) + + +def _edge_topology_from_extended( + mapping: torch.Tensor, + nlist: torch.Tensor, +) -> torch.Tensor: + """Convert an extended-coordinate nlist to sorted local edge topology rows.""" + nf, nloc, nsel = nlist.shape + nall = mapping.shape[1] + device = nlist.device + dst = torch.arange(nf * nloc, dtype=torch.long, device=device).repeat_interleave( + nsel + ) + frame = dst // nloc + center = dst % nloc + neighbor = nlist.reshape(-1).to(dtype=torch.long) + valid = neighbor >= 0 + neighbor_safe = torch.where(valid, neighbor, torch.zeros_like(neighbor)) + src_local = mapping.reshape(-1).index_select(0, frame * nall + neighbor_safe) + valid = valid & (src_local >= 0) & (src_local < nloc) + rows = torch.stack([frame[valid], src_local[valid], center[valid]], dim=1) + key = rows[:, 0] * nloc * nloc + rows[:, 1] * nloc + rows[:, 2] + return rows.index_select(0, torch.argsort(key)) + + +def _edge_geometry_from_extended( + extended_coord: torch.Tensor, + mapping: torch.Tensor, + nlist: torch.Tensor, +) -> torch.Tensor: + """Convert an extended-coordinate nlist to sorted edge-vector rows.""" + nf, nloc, nsel = nlist.shape + nall = extended_coord.shape[1] + device = extended_coord.device + dst = torch.arange(nf * nloc, dtype=torch.long, device=device).repeat_interleave( + nsel + ) + frame = dst // nloc + center = dst % nloc + neighbor = nlist.reshape(-1).to(dtype=torch.long) + valid = neighbor >= 0 + neighbor_safe = torch.where(valid, neighbor, torch.zeros_like(neighbor)) + src_local = mapping.reshape(-1).index_select(0, frame * nall + neighbor_safe) + src_valid = (src_local >= 0) & (src_local < nloc) + + coord_flat = extended_coord.reshape(nf * nall, 3) + src_coord = coord_flat.index_select(0, frame * nall + neighbor_safe) + dst_coord = coord_flat.index_select(0, frame * nall + center) + edge_vec = src_coord - dst_coord + valid = valid & src_valid & (torch.sum(edge_vec * edge_vec, dim=-1) > 1.0e-10) + topo = torch.stack([frame[valid], src_local[valid], center[valid]], dim=1) + key = topo[:, 0] * nloc * nloc + topo[:, 1] * nloc + topo[:, 2] + return edge_vec[valid].index_select(0, torch.argsort(key)).to(torch.float64) + + +def _assert_extended_atype_matches_mapping( + test_case: unittest.TestCase, + local_atype: torch.Tensor, + extended_atype: torch.Tensor, + mapping: torch.Tensor, +) -> None: + """Check that each real extended atom keeps the type of its mapped local atom.""" + nf, nall = extended_atype.shape + nloc = local_atype.shape[1] + frame = torch.arange(nf, dtype=torch.long, device=extended_atype.device) + frame = frame.unsqueeze(1).expand(nf, nall) + valid = extended_atype >= 0 + mapped = mapping.clamp(min=0, max=nloc - 1) + expected = local_atype[frame, mapped] + test_case.assertTrue(torch.equal(extended_atype[valid], expected[valid])) + + +@unittest.skipUnless( + torch.cuda.is_available() and nv_nlist.is_nv_available(), + "NVIDIA Toolkit-Ops CUDA path is unavailable", +) +class TestNVNList(unittest.TestCase): + def setUp(self) -> None: + self.device = env.DEVICE + + def _build_case( + self, nframes: int + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + coord_one = torch.tensor( + [ + [0.2, 0.2, 0.2], + [7.7, 0.2, 0.2], + [0.2, 7.6, 0.2], + [3.8, 3.9, 4.1], + ], + dtype=torch.float64, + device=self.device, + ) + coord = coord_one.unsqueeze(0).repeat(nframes, 1, 1) + if nframes > 1: + coord[1] = torch.tensor( + [ + [2.0, 2.0, 2.0], + [4.0, 2.0, 2.0], + [2.0, 4.0, 2.0], + [4.0, 4.0, 4.0], + ], + dtype=coord.dtype, + device=coord.device, + ) + atype = torch.tensor([[0, 1, 0, 1]], dtype=torch.int32, device=self.device) + atype = atype.repeat(nframes, 1) + box = torch.eye(3, dtype=torch.float64, device=self.device).reshape(1, 9) * 8.0 + box = box.repeat(nframes, 1) + return coord, atype, box + + def _build_overfull_case(self) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + coord = torch.tensor( + [ + [0.0, 0.0, 0.0], + [0.9, 0.0, 0.0], + [2.1, 0.0, 0.0], + [3.4, 0.0, 0.0], + [1.0, 1.0, 0.0], + [2.2, 1.1, 0.0], + ], + dtype=torch.float64, + device=self.device, + ).unsqueeze(0) + atype = torch.tensor( + [[0, 1, 0, 1, 0, 1]], dtype=torch.int32, device=self.device + ) + box = torch.eye(3, dtype=torch.float64, device=self.device).reshape(1, 9) * 20.0 + return coord, atype, box + + def _assert_nv_matches_native( + self, + coord: torch.Tensor, + atype: torch.Tensor, + box: torch.Tensor, + rcut: float, + sel: list[int], + force_cell_list: bool = False, + ) -> None: + # native: (extended_coord, extended_atype, mapping, nlist) + native = extend_input_and_build_neighbor_list( + coord, + atype, + rcut, + sel, + mixed_types=True, + box=box, + ) + # NeighborList strategy: (extended_coord, extended_atype, nlist, mapping) + builder = NvNeighborList() + if force_cell_list: + with mock.patch.object(nv_nlist, "NV_CELL_LIST_THRESHOLD", 1): + nv = builder.build(coord, atype, box, rcut, sel) + else: + nv = builder.build(coord, atype, box, rcut, sel) + native_coord, _, native_mapping, native_nlist = native + nv_coord, nv_atype, nv_nlist_out, nv_mapping = nv + # The strategy trims to sum(sel) itself, so the width is fixed. + self.assertEqual(nv_nlist_out.shape[-1], sum(sel)) + self.assertTrue( + torch.equal( + _edge_topology_from_extended(native_mapping, native_nlist), + _edge_topology_from_extended(nv_mapping, nv_nlist_out), + ) + ) + torch.testing.assert_close( + _edge_geometry_from_extended(native_coord, native_mapping, native_nlist), + _edge_geometry_from_extended(nv_coord, nv_mapping, nv_nlist_out), + atol=1.0e-10, + rtol=1.0e-10, + ) + _assert_extended_atype_matches_mapping(self, atype, nv_atype, nv_mapping) + + def test_cell_list_matches_native(self) -> None: + """The ``batch_cell_list`` method (forced via the threshold) matches the + native builder over a multi-frame periodic batch. End-to-end systems are + always below the threshold and take ``batch_naive``. + """ + coord, atype, box = self._build_case(2) + self._assert_nv_matches_native( + coord=coord, atype=atype, box=box, rcut=3.0, sel=[8], force_cell_list=True + ) + + def test_overfull_truncates_to_sel(self) -> None: + """A center with more real neighbors than ``sum(sel)`` is distance-sorted + and trimmed to the nearest ``sum(sel)`` -- the path behind the + compiled-graph width bug, which end-to-end systems never reach. + """ + coord, atype, box = self._build_overfull_case() + self._assert_nv_matches_native( + coord=coord, atype=atype, box=box, rcut=4.0, sel=[2], force_cell_list=False + ) + + def test_requires_periodic_box(self) -> None: + """The cell list needs a periodic box; ``box=None`` is rejected.""" + coord, atype, _ = self._build_case(1) + with self.assertRaises(ValueError): + NvNeighborList().build(coord, atype, None, 3.0, [8]) From c7ebcb74dcbdf50cd2cd8d8e196d5cb22c871a54 Mon Sep 17 00:00:00 2001 From: OutisLi Date: Sat, 6 Jun 2026 14:37:06 +0800 Subject: [PATCH 02/18] feat(dpa4): multiple updates for DPA4/SeZM feat(sezm): add activation checkpoint option feat(sezm): add cross node tensor product feat(sezm): add message node tensor product feat(sezm): seperate node lmax with edge lmax feat(sezm): add custom kernel for lmax=5-10 feat(sezm): add so3 grid projection refactor: change default values feat(sezm): tf32 infer --- deepmd/pt/entrypoints/freeze_pt2.py | 8 +- .../model/atomic_model/sezm_atomic_model.py | 5 +- deepmd/pt/model/descriptor/sezm.py | 273 ++++- .../pt/model/descriptor/sezm_nn/__init__.py | 28 +- .../pt/model/descriptor/sezm_nn/activation.py | 547 +-------- deepmd/pt/model/descriptor/sezm_nn/block.py | 192 +++- .../pt/model/descriptor/sezm_nn/embedding.py | 84 +- deepmd/pt/model/descriptor/sezm_nn/ffn.py | 215 ++-- .../pt/model/descriptor/sezm_nn/grid_net.py | 757 ++++++++++++ .../pt/model/descriptor/sezm_nn/indexing.py | 38 + .../pt/model/descriptor/sezm_nn/projection.py | 594 ++++++++++ deepmd/pt/model/descriptor/sezm_nn/so2.py | 198 +++- deepmd/pt/model/descriptor/sezm_nn/wignerd.py | 562 +++++++-- deepmd/pt/model/model/__init__.py | 2 + deepmd/pt/model/model/sezm_model.py | 784 ++++++++----- deepmd/pt/train/training.py | 20 +- deepmd/utils/argcheck.py | 200 +++- doc/model/dpa4.md | 342 ++++-- examples/water/dpa4/README.md | 4 +- examples/water/dpa4/input-spin.json | 42 +- examples/water/dpa4/input-zbl.json | 42 +- examples/water/dpa4/input.json | 56 +- examples/water/dpa4/input_dens.json | 42 +- examples/water/dpa4/input_multitask.json | 42 +- examples/water/dpa4/lora_ft.json | 42 +- source/tests/pt/model/test_descriptor_sezm.py | 125 +- .../test_descriptor_sezm_grid_projection.py | 1023 +++++++++++++++++ .../test_descriptor_sezm_s2_equivariance.py | 384 ------- source/tests/pt/model/test_sezm_model.py | 16 +- 29 files changed, 4727 insertions(+), 1940 deletions(-) create mode 100644 deepmd/pt/model/descriptor/sezm_nn/grid_net.py create mode 100644 deepmd/pt/model/descriptor/sezm_nn/projection.py create mode 100644 source/tests/pt/model/test_descriptor_sezm_grid_projection.py delete mode 100644 source/tests/pt/model/test_descriptor_sezm_s2_equivariance.py diff --git a/deepmd/pt/entrypoints/freeze_pt2.py b/deepmd/pt/entrypoints/freeze_pt2.py index bbbede110e..c1c36b4fb7 100644 --- a/deepmd/pt/entrypoints/freeze_pt2.py +++ b/deepmd/pt/entrypoints/freeze_pt2.py @@ -488,6 +488,7 @@ def freeze_sezm_to_pt2( from torch._inductor import ( aoti_compile_and_package, ) + from torch._inductor import config as inductor_config target_device = device if device is not None else DEVICE @@ -587,7 +588,12 @@ def freeze_sezm_to_pt2( exported = move_to_device_pass(exported, target_device) out_path_str = str(out_path) - aoti_compile_and_package(exported, package_path=out_path_str) + # Match the runtime eval compile path's Inductor option: triton.max_tiles=1 + # keeps pointwise grids 1D so the data-dependent compact-edge axis stays on + # Triton's x grid (limit 2**31); the default tiling places it on the y/z + # grid (limit 65535), which overflows for large systems. + with inductor_config.patch({"triton.max_tiles": 1}): + aoti_compile_and_package(exported, package_path=out_path_str) metadata = _collect_metadata(model, output_keys=output_keys, is_spin=is_spin) with zipfile.ZipFile(out_path_str, "a") as zf: diff --git a/deepmd/pt/model/atomic_model/sezm_atomic_model.py b/deepmd/pt/model/atomic_model/sezm_atomic_model.py index cfc1de910b..e96cd6b761 100644 --- a/deepmd/pt/model/atomic_model/sezm_atomic_model.py +++ b/deepmd/pt/model/atomic_model/sezm_atomic_model.py @@ -727,8 +727,9 @@ def _build_dens_fitting_kwargs(self) -> dict[str, Any]: """Reconstruct SeZM `dens`-head kwargs from energy head and descriptor.""" descriptor = self.descriptor kwargs = self._build_ener_fitting_kwargs() - kwargs["condition_lmax"] = int(descriptor.l_schedule[0]) - kwargs["latent_lmax"] = int(descriptor.l_schedule[-1]) + node_l_schedule = getattr(descriptor, "node_l_schedule", descriptor.l_schedule) + kwargs["condition_lmax"] = int(node_l_schedule[0]) + kwargs["latent_lmax"] = int(node_l_schedule[-1]) kwargs["channels"] = int(descriptor.channels) return kwargs diff --git a/deepmd/pt/model/descriptor/sezm.py b/deepmd/pt/model/descriptor/sezm.py index f7a3440749..257985a805 100644 --- a/deepmd/pt/model/descriptor/sezm.py +++ b/deepmd/pt/model/descriptor/sezm.py @@ -19,8 +19,8 @@ Layout notes ------------ -- Node-level backbone features use contiguous `(N, D, 1, C)` where - `D=(lmax+1)^2` and `C=channels`. +- Node-level backbone features use contiguous `(N, D_node, 1, C)` where + `D_node=(l_schedule[i]+extra_node_l+1)^2` and `C=channels`. - The singleton focus axis is kept only to reuse the existing equivariant operators; real multi-focus structure lives strictly inside `SO2Convolution`. - Edge-level SO(2) internal operators keep m-major reduced layout @@ -91,6 +91,7 @@ WignerDCalculator, build_edge_cache, build_edge_cache_from_edges, + build_edge_quaternion, edge_cache_to_dtype, fold_lora_state_dict_keys, get_promoted_dtype, @@ -98,6 +99,7 @@ has_lora, np_safe, nvtx_range, + safe_norm, safe_numpy_to_tensor, ) @@ -153,12 +155,12 @@ class DescrptSeZM(BaseDescriptor, nn.Module): Number of radial basis functions. radial_mlp Hidden layer sizes for radial networks. An output layer of size - `(l_schedule[0]+1)*channels` will be automatically appended. + `(l_schedule[0]+extra_node_l+1)*channels` will be automatically appended. use_env_seed If True, seed the initial node state with local-environment information: apply environment matrix FiLM conditioning on l=0 features using 4D `[s, s*r_hat]` representation, and enable the non-scalar geometric - initial embedding when `l_schedule[0] > 0`. If False, the initial state + initial embedding when `l_schedule[0] + extra_node_l > 0`. If False, the initial state contains only atom-local scalar features before message passing. FiLM deltas are normalized and scaled with learnable strengths initialized to small values. Internal dimensions are derived from `channels`: @@ -179,10 +181,19 @@ class DescrptSeZM(BaseDescriptor, nn.Module): mmax Maximum SO(2) order (|m|), only used when `m_schedule` is None. If None, defaults to the per-block `lmax` (i.e. `m_schedule = l_schedule`). + kmax + Maximum Wigner-D frame order (|k|) used by SO(3) grid nets. The frame set + is built as ``[0, -1, 1, ..., -kmax, kmax]``. ``kmax=0`` recovers the + S2-like k=0 slice, while ``kmax=1`` is the default low-cost setting that + opens odd/antisymmetric coupling paths. m_schedule Schedule of mmax per block, e.g. [2, 2, 1, 0]. Must satisfy `m_schedule[i] <= l_schedule[i]` for every block. A non-increasing schedule is recommended but not required. If set, `mmax` will be ignored. + extra_node_l + Extra node representation degree above each message-passing degree. + The node degree of block `i` is `l_schedule[i] + extra_node_l`, while + SO(2) message passing still uses `l_schedule[i]`. n_blocks Number of blocks (only used when `l_schedule` is None). so2_norm @@ -231,8 +242,21 @@ class DescrptSeZM(BaseDescriptor, nn.Module): effective GLU setting: ``4 * channels`` without GLU, ``(8 / 3) * channels`` with GLU, then round up to a multiple of 32. grid_mlp - If True, use the optional grid-MLP structure for the block-internal FFN - units. The final scalar output head is unchanged. + Either one boolean applied to every grid path, or three booleans + ``[node_wise, message_node, ffn]`` selecting the polynomial point-wise + grid MLP operation per grid path. On any path whose ``grid_branch`` + entry is positive it is overridden by branch mixing, and it has no + effect on the final ``l=0`` output head. + grid_branch + Either one non-negative integer applied to every grid path, or three + integers ``[node_wise, message_node, ffn]`` setting the number of + scalar-routed polynomial product branches per grid path. ``0`` disables + branch mixing on that path; positive values select branch mixing and + take precedence over ``grid_mlp``. Branch weights are computed from + ``l=0`` scalar features only, while each branch is a quadratic product + of channel-mixed grid fields. The ``node_wise`` and ``message_node`` + entries control the SO(2) convolution cross-grid paths, and the ``ffn`` + entry controls the block-internal FFN grid path. ffn_blocks Number of FFN subblocks per interaction block. sandwich_norm @@ -275,17 +299,33 @@ class DescrptSeZM(BaseDescriptor, nn.Module): ``activation_function="silu"``. ``ffn_enabled=True`` makes the block-internal FFN path use ``activation_function="silu"`` and ``glu_activation=True``. - S2-grid resolutions are resolved automatically per block. The e3nn - product grid uses ``[2 * mmax + 4, ceil_even(3 * lmax + 2)]`` in the - SO(2) branch, and the FFN branch lifts it to a square + S2-grid resolutions are resolved automatically per block. The + tensor-product grid uses ``[2 * mmax + 4, ceil_even(3 * lmax + 2)]`` + in the SO(2) branch, and the FFN branch lifts it to a square ``[max(R_phi, R_theta), max(R_phi, R_theta)]`` grid. Lebedev branches use the smallest packaged rule with precision at least ``3 * lmax``. The final ``l=0`` output FFN is unchanged. + ffn_so3_grid + If True, use the SO(3) Wigner-D grid in the block-internal FFN. This + option takes precedence over the FFN grid path and ignores + ``s2_activation[1]``. The final ``l=0`` output FFN is unchanged. + node_wise_s2 + If True, add an edge-local S2 product branch between source and + destination node features inside the SO(2) convolution. + node_wise_so3 + If True, use the corresponding edge-local SO(3) Wigner-D grid-net branch. + The source side is the query and the destination side is the context. + message_node_s2 + If True, add a post-aggregation S2 product branch between hidden messages + and destination node features before the SO(2) output projection. + message_node_so3 + If True, use the corresponding post-aggregation SO(3) Wigner-D grid-net + branch. The message is the query and the node state is the context. lebedev_quadrature Either one boolean applied to both S2 branches, or two booleans ``[so2_enabled, ffn_enabled]`` aligned with ``s2_activation``. If - enabled for a branch, that branch uses Lebedev quadrature instead of - the e3nn product grid in its S2 projector. + enabled for a branch, that branch uses packaged Lebedev quadrature + instead of the tensor-product sphere grid in its S2 projector. activation_function Base activation function for helper MLPs, the SO(2) gated activation path, and the final ``l=0`` output FFN. @@ -356,7 +396,9 @@ def __init__( lmax: int = 3, l_schedule: list[int] | None = None, mmax: int | None = 1, + kmax: int = 1, m_schedule: list[int] | None = None, + extra_node_l: int = 0, n_blocks: int = 3, so2_norm: bool = False, so2_layers: int = 4, @@ -370,7 +412,8 @@ def __init__( atten_v_proj: bool = False, atten_o_proj: bool = False, ffn_neurons: int = 0, - grid_mlp: bool = False, + grid_mlp: bool | list[bool] = False, + grid_branch: int | list[int] = 0, ffn_blocks: int = 1, sandwich_norm: list[bool] | None = None, mlp_bias: bool = False, @@ -378,6 +421,11 @@ def __init__( full_attn_res: str = "none", block_attn_res: str = "none", s2_activation: list[bool] | None = None, + ffn_so3_grid: bool = False, + node_wise_s2: bool = False, + node_wise_so3: bool = False, + message_node_s2: bool = False, + message_node_so3: bool = False, lebedev_quadrature: bool | list[bool] | None = True, activation_function: str = "silu", glu_activation: bool = True, @@ -460,8 +508,13 @@ def __init__( "`s2_activation` must be a list[bool] of length 2: [so2_activation, ffn_activation]" ) self.s2_activation = list(s2_activation) + self.ffn_so3_grid = bool(ffn_so3_grid) + self.node_wise_s2 = bool(node_wise_s2) + self.node_wise_so3 = bool(node_wise_so3) + self.message_node_s2 = bool(message_node_s2) + self.message_node_so3 = bool(message_node_so3) if lebedev_quadrature is None: - lebedev_quadrature = [False, False] + lebedev_quadrature = [True, True] elif isinstance(lebedev_quadrature, bool): lebedev_quadrature = [lebedev_quadrature, lebedev_quadrature] if not isinstance(lebedev_quadrature, list) or len(lebedev_quadrature) != 2: @@ -478,7 +531,7 @@ def __init__( # === Split effective activation config by branch === self.so2_s2_activation = self.s2_activation[0] - self.ffn_s2_activation = self.s2_activation[1] + self.ffn_s2_activation = False if self.ffn_so3_grid else self.s2_activation[1] self.so2_lebedev_quadrature = self.lebedev_quadrature[0] self.ffn_lebedev_quadrature = self.lebedev_quadrature[1] self.so2_activation_function = ( @@ -488,7 +541,9 @@ def __init__( "silu" if self.ffn_s2_activation else self.activation_function ) self.ffn_glu_activation = ( - True if self.ffn_s2_activation else self.glu_activation + True + if (self.ffn_s2_activation or self.ffn_so3_grid) + else self.glu_activation ) self.out_activation_function = self.activation_function self.out_glu_activation = self.glu_activation @@ -568,7 +623,13 @@ def __init__( # === L/M schedules === self._init_lm_schedules(lmax, n_blocks, l_schedule, mmax, m_schedule) + self.kmax = int(kmax) + if self.kmax < 0: + raise ValueError("`kmax` must be non-negative") + if self.kmax > int(lmax): + raise ValueError("`kmax` must be <= `lmax`") self.ebed_dims = [get_so3_dim_of_lmax(l) for l in self.l_schedule] + self._init_node_l_schedules(extra_node_l) self.rad_sizes_per_block = [l + 1 for l in self.l_schedule] self.so2_norm = bool(so2_norm) @@ -595,7 +656,27 @@ def __init__( self.ffn_neurons, glu_activation=self.out_glu_activation, ) - self.grid_mlp = bool(grid_mlp) + self.grid_mlp = self._broadcast_grid_setting( + grid_mlp, + name="grid_mlp", + cast=bool, + ) + self.grid_branch = self._broadcast_grid_setting( + grid_branch, + name="grid_branch", + cast=int, + non_negative=True, + ) + ( + self.node_wise_grid_mlp, + self.message_node_grid_mlp, + self.ffn_grid_mlp, + ) = self.grid_mlp + ( + self.node_wise_grid_branch, + self.message_node_grid_branch, + self.ffn_grid_branch, + ) = self.grid_branch self.ffn_blocks = int(ffn_blocks) if self.ffn_blocks < 1: raise ValueError("`ffn_blocks` must be >= 1") @@ -722,10 +803,11 @@ def __init__( ) # === Shared radial embedding: RBF -> per-l radial features === - # Output dimension is (lmax+1)*channels, directly usable by GIE and SO2Conv. + # Output dimension follows the first node degree, directly usable by + # GIE and truncated for each SO2Conv block. # radial_mlp specifies hidden layer sizes; input/output layers are prepended/appended. # Use fp32+ precision (same as RBF output) for numerical stability. - radial_out_dim = (self.lmax + 1) * self.channels + radial_out_dim = (self.node_l_schedule[0] + 1) * self.channels radial_mlp_layers = [self.n_radial, *self.radial_mlp, radial_out_dim] self.radial_embedding = RadialMLP( radial_mlp_layers, @@ -746,21 +828,35 @@ def __init__( dtype=self.compute_dtype, ) - self.use_gie = self.use_env_seed and self.l_schedule[0] > 0 + self.use_gie = self.use_env_seed and self.node_l_schedule[0] > 0 if self.use_gie: self.gie = GeometricInitialEmbedding( - lmax=self.l_schedule[0], + lmax=self.node_l_schedule[0], channels=self.channels, dtype=self.compute_dtype, # force fp32+ ) + if self.extra_node_l > 0: + self.gie_zonal_wigner_calc: WignerDCalculator | None = ( + WignerDCalculator( + lmax=self.node_l_schedule[0], + eps=self.eps, + dtype=self.compute_dtype, + ) + ) + else: + self.gie_zonal_wigner_calc = None else: self.gie = None + self.gie_zonal_wigner_calc = None blocks: list[SeZMInteractionBlock] = [] - for block_idx, (l_b, m_b) in enumerate(zip(self.l_schedule, self.m_schedule)): + for block_idx, (l_b, node_l_b, m_b) in enumerate( + zip(self.l_schedule, self.node_l_schedule, self.m_schedule) + ): blocks.append( SeZMInteractionBlock( lmax=l_b, + node_lmax=node_l_b, mmax=m_b, channels=self.channels, n_focus=self.n_focus, @@ -771,13 +867,24 @@ def __init__( radial_so2_mode=self.radial_so2_mode, radial_so2_rank=self.radial_so2_rank, ffn_neurons=self.block_ffn_neurons, - grid_mlp=self.grid_mlp, + node_wise_grid_mlp=self.node_wise_grid_mlp, + node_wise_grid_branch=self.node_wise_grid_branch, + message_node_grid_mlp=self.message_node_grid_mlp, + message_node_grid_branch=self.message_node_grid_branch, + ffn_grid_mlp=self.ffn_grid_mlp, + ffn_grid_branch=self.ffn_grid_branch, ffn_blocks=self.ffn_blocks, layer_scale=self.layer_scale, full_attn_res=self.full_attn_res_mode, block_attn_res=self.block_attn_res_mode, so2_s2_activation=self.so2_s2_activation, + node_wise_s2=self.node_wise_s2, + node_wise_so3=self.node_wise_so3, + message_node_s2=self.message_node_s2, + message_node_so3=self.message_node_so3, ffn_s2_activation=self.ffn_s2_activation, + ffn_so3_grid=self.ffn_so3_grid, + kmax=self.kmax, so2_lebedev_quadrature=self.so2_lebedev_quadrature, ffn_lebedev_quadrature=self.ffn_lebedev_quadrature, n_atten_head=self.n_atten_head, @@ -914,7 +1021,7 @@ def forward( force_embedding Optional precomputed equivariant force embedding with shape ``(nf * nloc, D, 1, channels)``, where - ``D = (l_schedule[0] + 1) ** 2``. This tensor is added to the + ``D = (node_l_schedule[0] + 1) ** 2``. This tensor is added to the initial SO(3) backbone state before the interaction blocks. charge_spin Frame-level charge and spin conditions with shape (nf, 2). @@ -1017,25 +1124,26 @@ def forward( edge_envelope=self.edge_envelope, radial_basis=self.radial_basis, n_radial=self.radial_basis.n_radial, - random_gamma=self.random_gamma, + # Random local-Z roll is a training-only augmentation; + # the model is roll-equivariant, so inference fixes gamma. + random_gamma=self.random_gamma and self.training, wigner_calc=self.wigner_calc, use_geometry_rbf_triton=(self.use_triton and not self.training), ) - lmax_0 = self.l_schedule[0] - ebed_dim_0 = get_so3_dim_of_lmax(lmax_0) # (lmax+1)^2 + ebed_dim_0 = self.node_ebed_dims[0] # (node_lmax+1)^2 x0 = type_ebed # (N, C) x0_out = x0 # (N, C) # === Step 5. Compute radial features once (fp32+) === - # Shape: (E, (lmax+1)*C) -> (E, lmax+1, C) + # Shape: (E, (node_lmax+1)*C) -> (E, node_lmax+1, C) radial_feat = None with nvtx_range("radial_embedding"): if edge_cache.src.numel() > 0: radial_feat = rearrange( self.radial_embedding(edge_cache.edge_rbf), "E (L C) -> E L C", - L=self.lmax + 1, + L=self.node_l_schedule[0] + 1, C=self.channels, ) # (E, lmax+1, C) if self.version >= 1.1: @@ -1068,10 +1176,12 @@ def forward( with nvtx_range("gie"): if self.use_gie and radial_feat is not None: # GIE only needs l>=1, slice radial_feat[:, 1:, :] + zonal_coupling = self._build_gie_zonal_coupling(edge_cache) x = x + self.gie( n_nodes=n_nodes, edge_cache=edge_cache, radial_feat=radial_feat[:, 1:, :], + zonal_coupling=zonal_coupling, ).unsqueeze(2) # === Step 9. Fuse edge type features into radial features (fp32+) === @@ -1149,7 +1259,7 @@ def forward_with_edges( force_embedding Optional precomputed equivariant force embedding with shape ``(nf * nloc, D, 1, channels)``, where - ``D = (l_schedule[0] + 1) ** 2``. This tensor is added to the + ``D = (node_l_schedule[0] + 1) ** 2``. This tensor is added to the initial SO(3) backbone state before the interaction blocks. charge_spin Frame-level charge and spin conditions with shape (nf, 2). @@ -1198,12 +1308,13 @@ def forward_with_edges( radial_basis=self.radial_basis, has_exclude_types=bool(self.exclude_types), edge_type_keep_mask=self._edge_type_keep_mask, - random_gamma=self.random_gamma, + # Random local-Z roll is a training-only augmentation; + # the model is roll-equivariant, so inference fixes gamma. + random_gamma=self.random_gamma and self.training, wigner_calc=self.wigner_calc, ) - lmax_0 = self.l_schedule[0] - ebed_dim_0 = get_so3_dim_of_lmax(lmax_0) # (lmax+1)^2 + ebed_dim_0 = self.node_ebed_dims[0] # (node_lmax+1)^2 x0 = type_ebed # (N, C) x0_out = x0 # (N, C) @@ -1211,7 +1322,9 @@ def forward_with_edges( with nvtx_range("radial_embedding"): radial_feat_flat = self.radial_embedding(edge_cache.edge_rbf) radial_feat = radial_feat_flat.reshape( - radial_feat_flat.shape[0], self.lmax + 1, self.channels + radial_feat_flat.shape[0], + self.node_l_schedule[0] + 1, + self.channels, ) # (E, lmax+1, C) if self.version >= 1.1: radial_feat = radial_feat * edge_cache.edge_env.reshape(-1, 1, 1) @@ -1242,10 +1355,12 @@ def forward_with_edges( # === Step 7. Geometric Initial Embedding (fp32+) === with nvtx_range("gie"): if self.use_gie: + zonal_coupling = self._build_gie_zonal_coupling(edge_cache) x = x + self.gie( n_nodes=n_nodes, edge_cache=edge_cache, radial_feat=radial_feat[:, 1:, :], + zonal_coupling=zonal_coupling, ).unsqueeze(2) # === Step 8. Fuse edge type features into radial features (fp32+) === @@ -1306,7 +1421,7 @@ def _forward_blocks( if not self.use_full_attn_res and not self.use_block_attn_res: # === Fast path without descriptor-level attention residuals === for i, block in enumerate(self.blocks): - x = x[:, : self.ebed_dims[i], :, :] + x = x[:, : self.node_ebed_dims[i], :, :] blk_radial = radial_feat_per_block[i] with nvtx_range(f"block_{i}"): x, _, _, _ = block(x, edge_cache, blk_radial) @@ -1324,7 +1439,7 @@ def node_l0_extractor(v: torch.Tensor) -> torch.Tensor: # === Step 2. Run each block with selective unit-history aggregation === for i, block in enumerate(self.blocks): - current_dim = self.ebed_dims[i] + current_dim = self.node_ebed_dims[i] current_x = x[:, :current_dim, :, :] truncated_unit_history = [ source[:, :current_dim, :, :] for source in unit_history @@ -1342,7 +1457,7 @@ def node_l0_extractor(v: torch.Tensor) -> torch.Tensor: x = block_output # === Step 3. Final aggregation over all completed unit representations === - final_dim = self.ebed_dims[-1] + final_dim = self.node_ebed_dims[-1] final_sources = [source[:, :final_dim, :, :] for source in unit_history] x = self.final_full_attn_res( sources=final_sources, @@ -1356,7 +1471,7 @@ def node_l0_extractor(v: torch.Tensor) -> torch.Tensor: # === Step 2. Run each block with selective block-history aggregation === for i, block in enumerate(self.blocks): - current_dim = self.ebed_dims[i] + current_dim = self.node_ebed_dims[i] current_x = x[:, :current_dim, :, :] truncated_block_history = [ source[:, :current_dim, :, :] for source in block_history @@ -1373,7 +1488,7 @@ def node_l0_extractor(v: torch.Tensor) -> torch.Tensor: x = block_output # === Step 3. Final aggregation over all completed block summaries === - final_dim = self.ebed_dims[-1] + final_dim = self.node_ebed_dims[-1] final_sources = [source[:, :final_dim, :, :] for source in block_history] x = self.final_block_attn_res( sources=final_sources, @@ -1382,6 +1497,41 @@ def node_l0_extractor(v: torch.Tensor) -> torch.Tensor: ).to(dtype=self.dtype) return x + def _build_gie_zonal_coupling( + self, + edge_cache: EdgeFeatureCache, + ) -> torch.Tensor | None: + """ + Build node-level zonal coupling for GIE when node degrees exceed MP degrees. + + Returns + ------- + torch.Tensor or None + Coupling with shape ``(E, D_node - 1)`` when ``extra_node_l > 0``; + otherwise None, letting GIE gather from the MP Wigner-D cache. + """ + if self.gie_zonal_wigner_calc is None: + return None + mp_row_count = self.ebed_dims[0] - 1 + mp_row_index = self.gie.non_scalar_row_index[:mp_row_count] + mp_m0_col_index = self.gie.zonal_m0_col_index_for_row[:mp_row_count] + mp_coupling = edge_cache.Dt_full[ + :, + mp_row_index, + mp_m0_col_index, + ] + edge_len = safe_norm(edge_cache.edge_vec, self.eps) + edge_quat = build_edge_quaternion( + edge_cache.edge_vec, + edge_len=edge_len, + eps=self.eps, + ) + extra_coupling = self.gie_zonal_wigner_calc.forward_zonal( + edge_quat, + lmin=self.lmax + 1, + ) + return torch.cat([mp_coupling, extra_coupling], dim=1) + def _apply_charge_spin_embedding( self, type_ebed: torch.Tensor, @@ -1447,6 +1597,31 @@ def _edge_type_keep_mask( keep = type_mask.index_select(0, type_ij.to(dtype=torch.long)) return keep.to(dtype=torch.bool) + @staticmethod + def _broadcast_grid_setting( + value: bool | int | list[bool] | list[int], + *, + name: str, + cast: type, + non_negative: bool = False, + ) -> list: + """Normalize a grid-path setting to ``[node_wise, message_node, ffn]``. + + A scalar is broadcast to all three grid paths, while a length-three + list is validated element-wise. When ``non_negative`` is set, every + entry must be ``>= 0``. + """ + entries = list(value) if isinstance(value, list) else [value, value, value] + if len(entries) != 3: + raise ValueError( + f"`{name}` must be a {cast.__name__} or a list[{cast.__name__}] " + "of length 3: [node_wise, message_node, ffn]" + ) + normalized = [cast(entry) for entry in entries] + if non_negative and any(entry < 0 for entry in normalized): + raise ValueError(f"`{name}` entries must be non-negative") + return normalized + def _resolve_ffn_neurons( self, ffn_neurons: int, @@ -1517,6 +1692,20 @@ def _init_lm_schedules( self.mmax = int(self.m_schedule[0]) + def _init_node_l_schedules(self, extra_node_l: int) -> None: + """Parse node degree schedules derived from message-passing schedules.""" + self.extra_node_l = int(extra_node_l) + if self.extra_node_l < 0: + raise ValueError("`extra_node_l` must be non-negative") + self.node_l_schedule = [ + int(l_value) + self.extra_node_l for l_value in self.l_schedule + ] + self.node_ebed_dims = [ + get_so3_dim_of_lmax(l_value) for l_value in self.node_l_schedule + ] + self.node_lmax = int(self.node_l_schedule[0]) + self.node_ebed_dim = int(self.node_ebed_dims[0]) + def _canonicalize_charge_spin( self, charge_spin: torch.Tensor | None, @@ -1821,7 +2010,9 @@ def serialize(self) -> dict[str, Any]: "n_blocks": self.n_blocks, "l_schedule": self.l_schedule, "mmax": self.mmax, + "kmax": self.kmax, "m_schedule": self.m_schedule, + "extra_node_l": self.extra_node_l, "channels": self.channels, "basis_type": self.basis_type, "n_radial": self.n_radial, @@ -1837,6 +2028,7 @@ def serialize(self) -> dict[str, Any]: "focus_dim": self.focus_dim, "ffn_neurons": self.ffn_neurons, "grid_mlp": self.grid_mlp, + "grid_branch": self.grid_branch, "ffn_blocks": self.ffn_blocks, "layer_scale": self.layer_scale, "n_atten_head": self.n_atten_head, @@ -1847,6 +2039,11 @@ def serialize(self) -> dict[str, Any]: "full_attn_res": self.full_attn_res_mode, "block_attn_res": self.block_attn_res_mode, "s2_activation": self.s2_activation, + "ffn_so3_grid": self.ffn_so3_grid, + "node_wise_s2": self.node_wise_s2, + "node_wise_so3": self.node_wise_so3, + "message_node_s2": self.message_node_s2, + "message_node_so3": self.message_node_so3, "lebedev_quadrature": self.lebedev_quadrature, "activation_function": self.activation_function, "glu_activation": self.glu_activation, diff --git a/deepmd/pt/model/descriptor/sezm_nn/__init__.py b/deepmd/pt/model/descriptor/sezm_nn/__init__.py index 9faa82ee97..a834ea1cd2 100644 --- a/deepmd/pt/model/descriptor/sezm_nn/__init__.py +++ b/deepmd/pt/model/descriptor/sezm_nn/__init__.py @@ -8,10 +8,7 @@ from .activation import ( GatedActivation, - S2GridProjector, SwiGLU, - SwiGLUS2Activation, - resolve_s2_grid_resolution, ) from .attention import ( segment_envelope_gated_softmax, @@ -45,7 +42,15 @@ from .ffn import ( EquivariantFFN, ) +from .grid_net import ( + BaseGridNet, + GridBranch, + GridMLP, + S2GridNet, + SO3GridNet, +) from .indexing import ( + build_gie_zonal_index, build_l_major_index, build_m_major_index, build_m_major_l_index, @@ -76,6 +81,13 @@ RMSNorm, ScalarRMSNorm, ) +from .projection import ( + BaseGridProjector, + S2GridProjector, + SO3GridProjector, + resolve_s2_grid_resolution, + resolve_so3_grid, +) from .radial import ( BridgingSwitch, C3CutoffEnvelope, @@ -115,6 +127,8 @@ __all__ = [ "ATTN_RES_MODES", "LEBEDEV_PRECISION_TO_NPOINTS", + "BaseGridNet", + "BaseGridProjector", "BridgingSwitch", "C3CutoffEnvelope", "ChannelLinear", @@ -129,6 +143,8 @@ "ForceEmbedding", "GatedActivation", "GeometricInitialEmbedding", + "GridBranch", + "GridMLP", "InnerClamp", "LoRASO2", "LoRASO3", @@ -136,9 +152,12 @@ "RadialBasis", "RadialMLP", "ReducedEquivariantRMSNorm", + "S2GridNet", "S2GridProjector", "SO2Convolution", "SO2Linear", + "SO3GridNet", + "SO3GridProjector", "SO3Linear", "ScalarRMSNorm", "SeZMDeNSFittingNet", @@ -147,13 +166,13 @@ "SeZMInteractionBlock", "SeZMTypeEmbedding", "SwiGLU", - "SwiGLUS2Activation", "WignerDCalculator", "apply_lora_to_sezm", "build_edge_cache", "build_edge_cache_from_edges", "build_edge_quaternion", "build_edge_type_feat", + "build_gie_zonal_index", "build_l_major_index", "build_m_major_index", "build_m_major_l_index", @@ -179,6 +198,7 @@ "quaternion_to_rotation_matrix", "quaternion_z_rotation", "resolve_s2_grid_resolution", + "resolve_so3_grid", "safe_norm", "safe_numpy_to_tensor", "segment_envelope_gated_softmax", diff --git a/deepmd/pt/model/descriptor/sezm_nn/activation.py b/deepmd/pt/model/descriptor/sezm_nn/activation.py index 1ce567f72b..d732680777 100644 --- a/deepmd/pt/model/descriptor/sezm_nn/activation.py +++ b/deepmd/pt/model/descriptor/sezm_nn/activation.py @@ -1,29 +1,22 @@ # SPDX-License-Identifier: LGPL-3.0-or-later """ -Activation and S2-grid helper modules for SeZM. +Activation helper modules for SeZM. -This module contains SeZM nonlinear operators, including GatedActivation, -point-wise SwiGLU, and the S2-grid projection helper used by the -S2 activation path. +This module contains coefficient-space nonlinear operators, including +GatedActivation and point-wise SwiGLU. Grid projectors and grid nets live in +dedicated modules so coefficient-space and function-space logic remain separate. """ from __future__ import ( annotations, ) -import math from typing import ( Any, ) import torch import torch.nn as nn -import torch.nn.functional as F -from e3nn.o3 import ( - FromS2Grid, - ToS2Grid, - spherical_harmonics, -) from deepmd.dpmodel.utils.seed import ( child_seed, @@ -44,15 +37,9 @@ ) from .indexing import ( - build_l_major_index, - build_m_major_index, build_m_major_l_index, map_degree_idx, ) -from .lebedev import ( - LEBEDEV_PRECISION_TO_NPOINTS, - load_lebedev_rule, -) from .so3 import ( FocusLinear, ) @@ -280,528 +267,4 @@ class SwiGLU(nn.Module): def forward(self, inputs: torch.Tensor) -> torch.Tensor: gate, value = torch.chunk(inputs, chunks=2, dim=-1) - return F.silu(gate) * value - - -class S2GridProjector(nn.Module): - """ - Project SO(3) coefficients to/from a flattened S2 grid. - - Parameters - ---------- - lmax - Maximum spherical harmonic degree. - mmax - Maximum order kept in the coefficient layout. If None, use ``lmax``. - dtype - Buffer dtype used by the projection matrices. - grid_resolution_list - Two-element resolution list. For ``grid_method='e3nn'`` it is - ``[R_phi, R_theta]`` and is converted to the ``e3nn`` - ``(lat, long) = (R_theta, R_phi)`` ordering. For - ``grid_method='lebedev'`` it is ``[precision, n_points]``. - coefficient_layout - Coefficient ordering expected by the caller: - - ``"packed"``: packed ``(l, m)`` order, optionally truncated by ``mmax``. - - ``"m_major"``: reduced m-major order used inside ``SO2Convolution``. - grid_method - S2 quadrature backend. Must be ``"e3nn"`` or ``"lebedev"``. - """ - - def __init__( - self, - *, - lmax: int, - mmax: int | None = None, - dtype: torch.dtype, - grid_resolution_list: list[int] | None = None, - coefficient_layout: str = "packed", - grid_method: str = "e3nn", - ) -> None: - super().__init__() - self.lmax = int(lmax) - self.mmax = int(self.lmax if mmax is None else mmax) - if self.mmax < 0: - raise ValueError("`mmax` must be non-negative") - if self.mmax > self.lmax: - raise ValueError("`mmax` must be <= `lmax`") - self.dtype = dtype - self.device = env.DEVICE - self.coefficient_layout = str(coefficient_layout).lower() - if self.coefficient_layout not in {"packed", "m_major"}: - raise ValueError( - "`coefficient_layout` must be either 'packed' or 'm_major'" - ) - self.grid_method = str(grid_method).lower() - if self.grid_method not in {"e3nn", "lebedev"}: - raise ValueError("`grid_method` must be either 'e3nn' or 'lebedev'") - - self.grid_resolution_list = _normalize_s2_grid_resolution( - self.lmax, - self.mmax, - grid_resolution_list, - method=self.grid_method, - ) - if self.grid_method == "e3nn": - self.phi_resolution, self.theta_resolution = self.grid_resolution_list - self.lebedev_precision = 0 - self.lebedev_npoints = 0 - else: - self.phi_resolution = 0 - self.theta_resolution = 0 - self.lebedev_precision, self.lebedev_npoints = self.grid_resolution_list - - coeff_index = self._build_coefficient_index(device=torch.device("cpu")) - self.coeff_dim = int(coeff_index.numel()) - to_grid_mat, from_grid_mat = self._build_projection_mats(coeff_index) - to_grid_mat = to_grid_mat.to(device=self.device, dtype=self.dtype) - from_grid_mat = from_grid_mat.to(device=self.device, dtype=self.dtype) - self.register_buffer("to_grid_mat", to_grid_mat, persistent=True) - self.register_buffer("from_grid_mat", from_grid_mat, persistent=True) - - def _build_coefficient_index(self, device: torch.device) -> torch.Tensor: - if self.coefficient_layout == "m_major": - return build_m_major_index(self.lmax, self.mmax, device=device) - if self.mmax == self.lmax: - return torch.arange((self.lmax + 1) ** 2, device=device, dtype=torch.long) - return build_l_major_index(self.lmax, self.mmax, device=device) - - def _rescale_truncated_orders(self, mat: torch.Tensor) -> None: - if self.lmax == self.mmax: - return - for l in range(self.lmax + 1): - if l <= self.mmax: - continue - start_idx = l * l - length = 2 * l + 1 - rescale = math.sqrt(length / float(2 * self.mmax + 1)) - mat[:, :, start_idx : start_idx + length].mul_(rescale) - - def _rescale_truncated_matrix(self, mat: torch.Tensor) -> None: - if self.lmax == self.mmax: - return - for l in range(self.lmax + 1): - if l <= self.mmax: - continue - start_idx = l * l - length = 2 * l + 1 - rescale = math.sqrt(length / float(2 * self.mmax + 1)) - mat[:, start_idx : start_idx + length].mul_(rescale) - - def _build_projection_mats( - self, coeff_index: torch.Tensor - ) -> tuple[torch.Tensor, torch.Tensor]: - if self.grid_method == "lebedev": - return self._build_lebedev_projection_mats(coeff_index) - return self._build_e3nn_projection_mats(coeff_index) - - def _build_e3nn_projection_mats( - self, coeff_index: torch.Tensor - ) -> tuple[torch.Tensor, torch.Tensor]: - with torch.device("cpu"): - to_grid = ToS2Grid( - self.lmax, - (self.theta_resolution, self.phi_resolution), - normalization="component", - device="cpu", - ) - to_grid_mat = torch.einsum("mbi,am->bai", to_grid.shb, to_grid.sha).detach() - self._rescale_truncated_orders(to_grid_mat) - - from_grid = FromS2Grid( - (self.theta_resolution, self.phi_resolution), - self.lmax, - normalization="component", - device="cpu", - ) - from_grid_mat = torch.einsum( - "am,mbi->bai", from_grid.sha, from_grid.shb - ).detach() - self._rescale_truncated_orders(from_grid_mat) - - to_grid_mat = to_grid_mat.flatten(0, 1).index_select(1, coeff_index) - from_grid_mat = ( - from_grid_mat.flatten(0, 1).permute(1, 0).index_select(0, coeff_index) - ) - return to_grid_mat, from_grid_mat - - def _build_lebedev_projection_mats( - self, coeff_index: torch.Tensor - ) -> tuple[torch.Tensor, torch.Tensor]: - with torch.device("cpu"): - points, weights = load_lebedev_rule( - self.lebedev_precision, - dtype=torch.float64, - device=torch.device("cpu"), - ) - harmonics = spherical_harmonics( - list(range(self.lmax + 1)), - points, - normalize=True, - normalization="norm", - ) - # e3nn's ``norm`` harmonics are ``component / sqrt(2*l+1)``. - # ``ToS2Grid(..., normalization="component")`` additionally divides - # every degree block by ``sqrt(lmax+1)``; keep the same convention so - # the Lebedev backend can replace the e3nn product-grid backend. - scale = math.sqrt(float(self.lmax + 1)) - degree_factors = harmonics.new_tensor( - [ - float(2 * l + 1) - for l in range(self.lmax + 1) - for _ in range(2 * l + 1) - ] - ) - to_grid_mat = harmonics / scale - # The packaged Lebedev weights sum to one. For ``norm`` harmonics, - # ``sum_a w_a Y_j(a) Y_k(a) = delta_jk / (2*l+1)``; the - # degree_factors and ``scale`` invert this normalization. - from_grid_mat = harmonics * ( - weights[:, None] * scale * degree_factors[None, :] - ) - self._rescale_truncated_matrix(to_grid_mat) - self._rescale_truncated_matrix(from_grid_mat) - - to_grid_mat = to_grid_mat.index_select(1, coeff_index) - from_grid_mat = from_grid_mat.index_select(1, coeff_index).transpose(0, 1) - return to_grid_mat, from_grid_mat - - def to_grid(self, embedding: torch.Tensor) -> torch.Tensor: - """Project coefficients ``(N, D, C)`` to a flattened grid ``(N, A, C)``.""" - return torch.einsum("aj,njc->nac", self.to_grid_mat, embedding) - - def from_grid(self, grid: torch.Tensor) -> torch.Tensor: - """Project a flattened grid ``(N, A, C)`` back to coefficients ``(N, D, C)``.""" - return torch.einsum("ja,nac->njc", self.from_grid_mat, grid) - - def serialize(self) -> dict[str, Any]: - return { - "@class": "S2GridProjector", - "@version": 1, - "config": { - "lmax": self.lmax, - "mmax": self.mmax, - "precision": RESERVED_PRECISION_DICT[self.dtype], - "grid_resolution_list": self.grid_resolution_list, - "coefficient_layout": self.coefficient_layout, - "grid_method": self.grid_method, - }, - "@variables": {}, - } - - @classmethod - def deserialize(cls, data: dict[str, Any]) -> S2GridProjector: - data = data.copy() - data_cls = data.pop("@class") - if data_cls != "S2GridProjector": - raise ValueError(f"Invalid class for S2GridProjector: {data_cls}") - version = int(data.pop("@version")) - check_version_compatibility(version, 1, 1) - config = data.pop("config") - data.pop("@variables", None) - precision = config.pop("precision") - config["dtype"] = PRECISION_DICT[precision] - return cls(**config) - - -class SwiGLUS2Activation(nn.Module): - """ - Apply the merged scalar/grid SwiGLU-S2 activation to SO(3) coefficients. - - The degree-0 slice provides two scalar paths: - - - a scalar ``SwiGLU`` branch that is merged back into the output ``l=0`` part - - a learned sigmoid gate that modulates the full output reconstructed from - the S2 grid path - - The equivariant branch projects the full ``2 * channels`` coefficients to the - S2 grid, multiplies the two channel halves point-wise on the grid, projects - back to coefficients, and applies the scalar sigmoid gate. - - Parameters - ---------- - lmax - Maximum spherical harmonic degree. - mmax - Maximum order kept in the coefficient layout. If None, use ``lmax``. - channels - Output channel count after SwiGLU. The input is expected to have - ``2 * channels`` on the last axis. - dtype - Projection buffer dtype. - n_focus - Number of focus streams in the input layout. - layout - Tensor layout convention: - - ``"ndfc"`` for ``(N, D, F, C)`` - - ``"nfdc"`` for ``(N, F, D, C)`` - grid_resolution_list - Two-element list ``[R_phi, R_theta]``. - coefficient_layout - Coefficient ordering: ``"packed"`` or ``"m_major"``. - grid_method - S2 quadrature backend. Must be ``"e3nn"`` or ``"lebedev"``. - mlp_bias - Whether the scalar sigmoid projection uses bias. - trainable - Whether parameters are trainable. - seed - Random seed for the scalar sigmoid projection. - """ - - def __init__( - self, - *, - lmax: int, - mmax: int | None = None, - channels: int, - dtype: torch.dtype, - n_focus: int = 1, - layout: str = "ndfc", - grid_resolution_list: list[int] | None = None, - coefficient_layout: str = "packed", - grid_method: str = "e3nn", - mlp_bias: bool = False, - trainable: bool, - seed: int | list[int] | None = None, - ) -> None: - super().__init__() - self.lmax = int(lmax) - self.mmax = int(self.lmax if mmax is None else mmax) - self.channels = int(channels) - self.dtype = dtype - self.n_focus = int(n_focus) - self.mlp_bias = bool(mlp_bias) - self.layout = str(layout).lower() - if self.layout not in {"ndfc", "nfdc"}: - raise ValueError("`layout` must be either 'ndfc' or 'nfdc'") - self.coefficient_layout = str(coefficient_layout).lower() - self.grid_method = str(grid_method).lower() - self.grid_resolution_list = _normalize_s2_grid_resolution( - self.lmax, - self.mmax, - grid_resolution_list, - method=self.grid_method, - ) - self.scalar_act = SwiGLU() - self.scalar_gate = FocusLinear( - in_channels=2 * self.channels, - out_channels=self.channels, - n_focus=self.n_focus, - dtype=self.dtype, - bias=self.mlp_bias, - trainable=trainable, - seed=child_seed(seed, 0), - init_std=0.01, - ) - self.projector: S2GridProjector | None - if self.lmax == 0: - self.projector = None - self.coeff_dim = 1 - else: - self.projector = S2GridProjector( - lmax=self.lmax, - mmax=self.mmax, - dtype=self.dtype, - grid_resolution_list=self.grid_resolution_list, - coefficient_layout=self.coefficient_layout, - grid_method=self.grid_method, - ) - self.coeff_dim = self.projector.coeff_dim - - def forward(self, x: torch.Tensor) -> torch.Tensor: - """ - Parameters - ---------- - x - Input tensor with last dimension ``2 * channels``. - - Returns - ------- - torch.Tensor - Activated tensor with the same coefficient layout and ``channels`` on - the last axis. - """ - input_dtype = x.dtype - # Promote before slicing to avoid the TorchInductor AMP compile bug on - # the scalar SwiGLU branch in PyTorch 2.11. - scalar_inputs = self._extract_scalar_inputs(x.to(dtype=self.dtype)) - scalar_outputs = self.scalar_act(scalar_inputs) - - if self.projector is None: - return self._restore_scalar_outputs(scalar_outputs.to(dtype=input_dtype)) - - gate_scalars = torch.sigmoid(self.scalar_gate(scalar_inputs)) - x_flat, shape_info = self._flatten_inputs(x) - x_grid = self.projector.to_grid(x_flat.to(dtype=self.dtype)) - x_grid_1, x_grid_2 = torch.chunk(x_grid, chunks=2, dim=-1) - out_flat = self.projector.from_grid(x_grid_1 * x_grid_2) - outputs = self._restore_outputs(out_flat, shape_info) - outputs = outputs * self._broadcast_scalar_gate(gate_scalars) - self._merge_scalar_outputs(outputs, scalar_outputs) - return outputs.to(dtype=input_dtype) - - def _extract_scalar_inputs(self, x: torch.Tensor) -> torch.Tensor: - if self.layout == "ndfc": - return x.select(dim=1, index=0) - return x.select(dim=2, index=0) - - def _broadcast_scalar_gate(self, gate_scalars: torch.Tensor) -> torch.Tensor: - if self.layout == "ndfc": - return gate_scalars.unsqueeze(1) - return gate_scalars.unsqueeze(2) - - def _restore_scalar_outputs(self, scalar_outputs: torch.Tensor) -> torch.Tensor: - if self.layout == "ndfc": - return scalar_outputs.unsqueeze(1) - return scalar_outputs.unsqueeze(2) - - def _flatten_inputs( - self, x: torch.Tensor - ) -> tuple[torch.Tensor, tuple[int, int, int]]: - if self.layout == "ndfc": - n_batch, coeff_dim, n_focus, _ = x.shape - return ( - x.permute(0, 2, 1, 3).reshape( - n_batch * n_focus, coeff_dim, x.shape[-1] - ), - (n_batch, coeff_dim, n_focus), - ) - n_batch, n_focus, coeff_dim, _ = x.shape - return ( - x.reshape(n_batch * n_focus, coeff_dim, x.shape[-1]), - (n_batch, coeff_dim, n_focus), - ) - - def _restore_outputs( - self, x: torch.Tensor, shape_info: tuple[int, int, int] - ) -> torch.Tensor: - n_batch, coeff_dim, n_focus = shape_info - if self.layout == "ndfc": - return x.reshape(n_batch, n_focus, coeff_dim, self.channels).permute( - 0, 2, 1, 3 - ) - return x.reshape(n_batch, n_focus, coeff_dim, self.channels) - - def _merge_scalar_outputs( - self, outputs: torch.Tensor, scalar_outputs: torch.Tensor - ) -> None: - if self.layout == "ndfc": - outputs[:, 0, :, :].add_(scalar_outputs) - else: - outputs[:, :, 0, :].add_(scalar_outputs) - - def serialize(self) -> dict[str, Any]: - trainable = all(p.requires_grad for p in self.parameters()) - state = self.state_dict() - return { - "@class": "SwiGLUS2Activation", - "@version": 1, - "config": { - "lmax": self.lmax, - "mmax": self.mmax, - "channels": self.channels, - "precision": RESERVED_PRECISION_DICT[self.dtype], - "n_focus": self.n_focus, - "layout": self.layout, - "grid_resolution_list": self.grid_resolution_list, - "coefficient_layout": self.coefficient_layout, - "grid_method": self.grid_method, - "mlp_bias": self.mlp_bias, - "trainable": trainable, - "seed": None, - }, - "@variables": {key: np_safe(value) for key, value in state.items()}, - } - - @classmethod - def deserialize(cls, data: dict[str, Any]) -> SwiGLUS2Activation: - data = data.copy() - data_cls = data.pop("@class") - if data_cls != "SwiGLUS2Activation": - raise ValueError(f"Invalid class for SwiGLUS2Activation: {data_cls}") - version = int(data.pop("@version")) - check_version_compatibility(version, 1, 1) - config = data.pop("config") - variables = data.pop("@variables") - precision = config.pop("precision") - config["dtype"] = PRECISION_DICT[precision] - obj = cls(**config) - template = obj.state_dict() - state = { - key: safe_numpy_to_tensor( - value, device=template[key].device, dtype=template[key].dtype - ) - for key, value in variables.items() - } - obj.load_state_dict(state) - return obj - - -def resolve_s2_grid_resolution( - lmax: int, - mmax: int, - *, - method: str = "e3nn", -) -> list[int]: - """ - Resolve the default S2 grid resolution. - - For ``method='e3nn'``, the automatic default uses even azimuthal sampling - ``R_phi = 2 * mmax + 4`` and even polar sampling - ``R_theta = ceil_even(3 * lmax + 2)``. - - For ``method='lebedev'``, the automatic default picks the smallest packaged - Lebedev rule whose algebraic precision is at least ``3 * lmax`` and returns - ``[precision, n_points]``. - """ - method = str(method).lower() - if method not in {"e3nn", "lebedev"}: - raise ValueError("`method` must be either 'e3nn' or 'lebedev'") - if method == "lebedev": - required_precision = 3 * int(lmax) - for precision, n_points in LEBEDEV_PRECISION_TO_NPOINTS.items(): - if precision >= required_precision: - return [precision, n_points] - raise ValueError( - f"No packaged Lebedev rule has precision >= {required_precision}" - ) - - phi_resolution = 2 * mmax + 4 - theta_resolution = 3 * lmax + 2 - theta_resolution += theta_resolution % 2 - return [phi_resolution, theta_resolution] - - -def _normalize_s2_grid_resolution( - lmax: int, - mmax: int, - grid_resolution_list: list[int] | None, - *, - method: str, -) -> list[int]: - """Resolve default grids or validate already-resolved low-level grids.""" - method = str(method).lower() - if grid_resolution_list is None: - return resolve_s2_grid_resolution(lmax, mmax, method=method) - if method == "lebedev": - if len(grid_resolution_list) != 2: - raise ValueError( - "Lebedev `grid_resolution_list` must be [precision, n_points]" - ) - precision = int(grid_resolution_list[0]) - n_points = int(grid_resolution_list[1]) - expected_n_points = LEBEDEV_PRECISION_TO_NPOINTS.get(precision) - if expected_n_points != n_points: - raise ValueError( - "Lebedev `grid_resolution_list` must match a packaged " - f"[precision, n_points] pair; got [{precision}, {n_points}]" - ) - return [precision, n_points] - - if len(grid_resolution_list) != 2: - raise ValueError("`grid_resolution_list` must contain two integers") - resolution = [int(grid_resolution_list[0]), int(grid_resolution_list[1])] - if resolution[0] < 1 or resolution[1] < 1: - raise ValueError("grid resolutions must be positive") - return resolution + return gate * torch.sigmoid(gate) * value diff --git a/deepmd/pt/model/descriptor/sezm_nn/block.py b/deepmd/pt/model/descriptor/sezm_nn/block.py index ddc5d9847e..d1ab12ad99 100644 --- a/deepmd/pt/model/descriptor/sezm_nn/block.py +++ b/deepmd/pt/model/descriptor/sezm_nn/block.py @@ -11,6 +11,7 @@ annotations, ) +import os from typing import ( TYPE_CHECKING, Any, @@ -18,6 +19,9 @@ import torch import torch.nn as nn +from torch.utils.checkpoint import ( + checkpoint, +) from deepmd.dpmodel.utils.seed import ( child_seed, @@ -78,9 +82,13 @@ class SeZMInteractionBlock(nn.Module): Parameters ---------- lmax - Maximum spherical harmonic degree. + Maximum message-passing spherical harmonic degree. + node_lmax + Maximum node representation degree. If None, equals `lmax`. mmax Maximum SO(2) order (|m|) mixed inside SO(2) convolution. + kmax + Maximum Wigner-D frame order (|k|) used by SO(3) grid branches. channels Total channels per (l, m) coefficient. n_focus @@ -132,9 +140,27 @@ class SeZMInteractionBlock(nn.Module): If True, apply post-norm on each FFN subblock output before the residual add. ffn_neurons Hidden dimension for each FFN subblock. - grid_mlp - If True, use the optional grid-MLP structure for the block-internal FFN - units. The final descriptor output head is unaffected. + node_wise_grid_mlp + If True, select the polynomial grid MLP operation for the SO(2) + convolution node-wise cross-grid path. + node_wise_grid_branch + Number of scalar-routed polynomial product branches for the node-wise + cross-grid path. ``0`` disables branch mixing; positive values take + precedence over ``node_wise_grid_mlp``. + message_node_grid_mlp + If True, select the polynomial grid MLP operation for the SO(2) + convolution message-node cross-grid path. + message_node_grid_branch + Number of scalar-routed polynomial product branches for the + message-node cross-grid path. ``0`` disables branch mixing; positive + values take precedence over ``message_node_grid_mlp``. + ffn_grid_mlp + If True, select the polynomial grid MLP operation for the + block-internal FFN grid path. + ffn_grid_branch + Number of scalar-routed polynomial product branches for the FFN grid + path. ``0`` disables branch mixing; positive values take precedence + over ``ffn_grid_mlp``. ffn_blocks Number of FFN subblocks per block. layer_scale @@ -153,9 +179,24 @@ class SeZMInteractionBlock(nn.Module): so2_s2_activation If True, enable the merged scalar/grid SwiGLU-S2 activation in the SO(2) branch. + node_wise_s2 + If True, enable the edge-local source-destination S2 product branch in + the SO(2) convolution. + node_wise_so3 + If True, enable the corresponding edge-local SO(3) Wigner-D grid branch + in the SO(2) convolution. + message_node_s2 + If True, enable the post-aggregation message-node S2 product branch in + the SO(2) convolution. + message_node_so3 + If True, enable the corresponding post-aggregation SO(3) Wigner-D grid + branch in the SO(2) convolution. ffn_s2_activation If True, enable the merged scalar/grid SwiGLU-S2 activation in the default FFN activation path. + ffn_so3_grid + If True, use the SO(3) Wigner-D grid in the block-internal FFN. This + takes precedence over ``ffn_s2_activation``. so2_lebedev_quadrature If True, use Lebedev quadrature for the SO(2) S2 activation projector. ffn_lebedev_quadrature @@ -190,7 +231,9 @@ def __init__( self, *, lmax: int, + node_lmax: int | None = None, mmax: int | None = None, + kmax: int = 1, channels: int, n_focus: int = 1, focus_dim: int = 0, @@ -209,13 +252,23 @@ def __init__( ffn_pre_norm: bool = True, ffn_post_norm: bool = False, ffn_neurons: int = 96, - grid_mlp: bool = False, + node_wise_grid_mlp: bool = False, + node_wise_grid_branch: int = 0, + message_node_grid_mlp: bool = False, + message_node_grid_branch: int = 0, + ffn_grid_mlp: bool = False, + ffn_grid_branch: int = 0, ffn_blocks: int = 1, layer_scale: bool = False, full_attn_res: str = "none", block_attn_res: str = "none", so2_s2_activation: bool = False, + node_wise_s2: bool = False, + node_wise_so3: bool = False, + message_node_s2: bool = False, + message_node_so3: bool = False, ffn_s2_activation: bool = False, + ffn_so3_grid: bool = False, so2_lebedev_quadrature: bool = False, ffn_lebedev_quadrature: bool = False, so2_activation_function: str = "silu", @@ -230,11 +283,19 @@ def __init__( ) -> None: super().__init__() self.lmax = int(lmax) + self.node_lmax = self.lmax if node_lmax is None else int(node_lmax) + if self.node_lmax < self.lmax: + raise ValueError("`node_lmax` must be >= `lmax`") + self.mp_ebed_dim = (self.lmax + 1) ** 2 + self.node_ebed_dim = (self.node_lmax + 1) ** 2 self.mmax = int(self.lmax if mmax is None else mmax) if self.mmax < 0: raise ValueError("`mmax` must be non-negative") if self.mmax > self.lmax: raise ValueError("`mmax` must be <= `lmax`") + self.kmax = int(kmax) + if self.kmax < 0: + raise ValueError("`kmax` must be non-negative") self.channels = int(channels) self.n_focus = int(n_focus) if self.n_focus < 1: @@ -261,7 +322,21 @@ def __init__( self.ffn_pre_norm = bool(ffn_pre_norm) self.ffn_post_norm = bool(ffn_post_norm) self.ffn_neurons = int(ffn_neurons) - self.grid_mlp = bool(grid_mlp) + self.node_wise_grid_mlp = bool(node_wise_grid_mlp) + self.node_wise_grid_branch = int(node_wise_grid_branch) + self.message_node_grid_mlp = bool(message_node_grid_mlp) + self.message_node_grid_branch = int(message_node_grid_branch) + self.ffn_grid_mlp = bool(ffn_grid_mlp) + self.ffn_grid_branch = int(ffn_grid_branch) + if ( + min( + self.node_wise_grid_branch, + self.message_node_grid_branch, + self.ffn_grid_branch, + ) + < 0 + ): + raise ValueError("grid branch counts must be non-negative") self.ffn_blocks = int(ffn_blocks) if self.ffn_blocks < 1: raise ValueError("`ffn_blocks` must be >= 1") @@ -283,7 +358,12 @@ def __init__( "`full_attn_res` and `block_attn_res` cannot both be enabled" ) self.so2_s2_activation = bool(so2_s2_activation) + self.node_wise_s2 = bool(node_wise_s2) + self.node_wise_so3 = bool(node_wise_so3) + self.message_node_s2 = bool(message_node_s2) + self.message_node_so3 = bool(message_node_so3) self.ffn_s2_activation = bool(ffn_s2_activation) + self.ffn_so3_grid = bool(ffn_so3_grid) self.so2_lebedev_quadrature = bool(so2_lebedev_quadrature) self.ffn_lebedev_quadrature = bool(ffn_lebedev_quadrature) self.so2_activation_function = str(so2_activation_function) @@ -329,6 +409,7 @@ def __init__( self.so2_conv = SO2Convolution( lmax=self.lmax, mmax=self.mmax, + kmax=self.kmax, channels=self.channels, n_focus=self.n_focus, focus_dim=self.focus_dim, @@ -344,6 +425,14 @@ def __init__( atten_v_proj=self.use_atten_v_proj, atten_o_proj=self.use_atten_o_proj, s2_activation=self.so2_s2_activation, + node_wise_grid_mlp=self.node_wise_grid_mlp, + node_wise_grid_branch=self.node_wise_grid_branch, + message_node_grid_mlp=self.message_node_grid_mlp, + message_node_grid_branch=self.message_node_grid_branch, + node_wise_s2=self.node_wise_s2, + node_wise_so3=self.node_wise_so3, + message_node_s2=self.message_node_s2, + message_node_so3=self.message_node_so3, lebedev_quadrature=self.so2_lebedev_quadrature, activation_function=self.so2_activation_function, mlp_bias=self.mlp_bias, @@ -365,7 +454,7 @@ def __init__( if self.ffn_pre_norm: pre_ffn_norms.append( EquivariantRMSNorm( - self.lmax, + self.node_lmax, self.channels, n_focus=1, dtype=self.compute_dtype, @@ -378,7 +467,7 @@ def __init__( if self.ffn_post_norm: post_ffn_norms.append( EquivariantRMSNorm( - self.lmax, + self.node_lmax, self.channels, n_focus=1, dtype=self.compute_dtype, @@ -390,12 +479,15 @@ def __init__( ffns.append( EquivariantFFN( - lmax=self.lmax, + lmax=self.node_lmax, channels=self.channels, hidden_channels=ffn_neurons, - grid_mlp=self.grid_mlp, + kmax=self.kmax, + grid_mlp=self.ffn_grid_mlp, + grid_branch=self.ffn_grid_branch, dtype=dtype, s2_activation=self.ffn_s2_activation, + ffn_so3_grid=self.ffn_so3_grid, lebedev_quadrature=self.ffn_lebedev_quadrature, activation_function=self.ffn_activation_function, glu_activation=self.ffn_glu_activation, @@ -543,6 +635,22 @@ def _extract_l0_from_canonical(self, value: torch.Tensor) -> torch.Tensor: """ return value[:, 0, :, :].reshape(value.shape[0], self.channels) + def _use_infer_activation_checkpoint(self, *tensors: torch.Tensor) -> bool: + """Return whether eval-time activation checkpointing should be used. + + Disabled on the compiled inference path (``DP_COMPILE_INFER``): Inductor + already reuses activation buffers, so recomputation only adds latency for + a negligible memory gain there. + """ + return ( + not self.training + and os.environ.get("DP_ACT_INFER") == "1" + and os.environ.get("DP_COMPILE_INFER", "").strip().lower() + not in {"1", "true", "yes", "on"} + and torch.is_grad_enabled() + and any(tensor.requires_grad for tensor in tensors) + ) + def _run_so2_unit( self, x: torch.Tensor, @@ -566,14 +674,45 @@ def _run_so2_unit( torch.Tensor SO(2) unit output with shape `(N, D, 1, C)`. """ + if self._use_infer_activation_checkpoint(x, radial_feat): + edge_cache_no_proj = edge_cache._replace( + D_to_m_cache=None, + Dt_from_m_cache=None, + ) + return checkpoint( + lambda x_, radial_feat_: self._run_so2_unit_impl( + x_, + edge_cache_no_proj, + radial_feat_, + ), + x, + radial_feat, + use_reentrant=False, + preserve_rng_state=True, + ) + return self._run_so2_unit_impl(x, edge_cache, radial_feat) + + def _run_so2_unit_impl( + self, + x: torch.Tensor, + edge_cache: EdgeFeatureCache, + radial_feat: torch.Tensor, + ) -> torch.Tensor: + """Run the SO(2) unit implementation.""" n_node = x.shape[0] - ebed_dim = x.shape[1] channels = self.channels - x_pre = self.pre_so2_norm(x) + use_full_node = self.node_lmax == self.lmax + x_so2 = x if use_full_node else x[:, : self.mp_ebed_dim, :, :] + x_pre = self.pre_so2_norm(x_so2) so2_unit_output = self.so2_conv( - x_pre.reshape(n_node, ebed_dim, channels), edge_cache, radial_feat + x_pre.reshape(n_node, x_so2.shape[1], channels), edge_cache, radial_feat ) - return self.post_so2_norm(so2_unit_output.unsqueeze(2)) + so2_unit_output = self.post_so2_norm(so2_unit_output.unsqueeze(2)) + if use_full_node: + return so2_unit_output + output = x.new_zeros(x.shape) + output[:, : self.mp_ebed_dim, :, :] = so2_unit_output + return output def _run_ffn_unit(self, x: torch.Tensor, unit_idx: int) -> torch.Tensor: """ @@ -591,6 +730,17 @@ def _run_ffn_unit(self, x: torch.Tensor, unit_idx: int) -> torch.Tensor: torch.Tensor FFN unit output with shape `(N, D, 1, C)`. """ + if self._use_infer_activation_checkpoint(x): + return checkpoint( + lambda x_: self._run_ffn_unit_impl(x_, unit_idx), + x, + use_reentrant=False, + preserve_rng_state=True, + ) + return self._run_ffn_unit_impl(x, unit_idx) + + def _run_ffn_unit_impl(self, x: torch.Tensor, unit_idx: int) -> torch.Tensor: + """Run one FFN subblock implementation.""" n_node = x.shape[0] ebed_dim = x.shape[1] channels = self.channels @@ -773,7 +923,9 @@ def serialize(self) -> dict[str, Any]: "@version": 1, "config": { "lmax": self.lmax, + "node_lmax": self.node_lmax, "mmax": self.mmax, + "kmax": self.kmax, "channels": self.channels, "n_focus": self.n_focus, "focus_dim": self.focus_dim, @@ -792,12 +944,22 @@ def serialize(self) -> dict[str, Any]: "ffn_pre_norm": self.ffn_pre_norm, "ffn_post_norm": self.ffn_post_norm, "ffn_neurons": self.ffn_neurons, - "grid_mlp": self.grid_mlp, + "node_wise_grid_mlp": self.node_wise_grid_mlp, + "node_wise_grid_branch": self.node_wise_grid_branch, + "message_node_grid_mlp": self.message_node_grid_mlp, + "message_node_grid_branch": self.message_node_grid_branch, + "ffn_grid_mlp": self.ffn_grid_mlp, + "ffn_grid_branch": self.ffn_grid_branch, "ffn_blocks": self.ffn_blocks, "full_attn_res": self.full_attn_res_mode, "block_attn_res": self.block_attn_res_mode, "so2_s2_activation": self.so2_s2_activation, + "node_wise_s2": self.node_wise_s2, + "node_wise_so3": self.node_wise_so3, + "message_node_s2": self.message_node_s2, + "message_node_so3": self.message_node_so3, "ffn_s2_activation": self.ffn_s2_activation, + "ffn_so3_grid": self.ffn_so3_grid, "so2_lebedev_quadrature": self.so2_lebedev_quadrature, "ffn_lebedev_quadrature": self.ffn_lebedev_quadrature, "so2_activation_function": self.so2_activation_function, diff --git a/deepmd/pt/model/descriptor/sezm_nn/embedding.py b/deepmd/pt/model/descriptor/sezm_nn/embedding.py index 4b298e3483..e31d7e2b65 100644 --- a/deepmd/pt/model/descriptor/sezm_nn/embedding.py +++ b/deepmd/pt/model/descriptor/sezm_nn/embedding.py @@ -40,8 +40,8 @@ ) from .indexing import ( + build_gie_zonal_index, get_so3_dim_of_lmax, - map_degree_idx, ) from .utils import ( np_safe, @@ -154,7 +154,7 @@ class GeometricInitialEmbedding(nn.Module): Parameters ---------- lmax - Maximum degree, should match ``l_schedule[0]``. + Maximum node degree for the initial embedding. channels Number of channels per (l, m) coefficient. dtype @@ -175,49 +175,24 @@ def __init__( self.dtype = dtype self.device = env.DEVICE self.precision = RESERVED_PRECISION_DICT[dtype] - if self.lmax > 0: - packed_degree_by_row = map_degree_idx(self.lmax, device=self.device) - # These aligned arrays describe one packed non-scalar row at a time. - # non_scalar_row_index[k] picks the output row in the packed SO(3) layout. - # zonal_m0_col_index_for_row[k] picks the matching m=0 column in Dt_full. - # radial_slot_index_for_row[k] picks the matching degree slot in radial_feat. - non_scalar_row_index = torch.arange( - 1, self.ebed_dim, device=self.device, dtype=torch.long - ) - non_scalar_degree_by_row = packed_degree_by_row[1:] - zonal_m0_col_index_for_row = non_scalar_degree_by_row * ( - non_scalar_degree_by_row + 1 - ) - radial_slot_index_for_row = non_scalar_degree_by_row - 1 - self.register_buffer( - "non_scalar_row_index", non_scalar_row_index, persistent=True - ) - self.register_buffer( - "zonal_m0_col_index_for_row", - zonal_m0_col_index_for_row, - persistent=True, - ) - self.register_buffer( - "radial_slot_index_for_row", - radial_slot_index_for_row, - persistent=True, - ) - else: - self.register_buffer( - "non_scalar_row_index", - torch.empty(0, device=self.device, dtype=torch.long), - persistent=True, - ) - self.register_buffer( - "zonal_m0_col_index_for_row", - torch.empty(0, device=self.device, dtype=torch.long), - persistent=True, - ) - self.register_buffer( - "radial_slot_index_for_row", - torch.empty(0, device=self.device, dtype=torch.long), - persistent=True, - ) + ( + node_row_index, + node_zonal_m0_col_index, + node_radial_l_index, + ) = build_gie_zonal_index(self.lmax, device=self.device) + # One aligned entry per non-scalar node row: output row, local m=0 + # column, and the matching radial degree slot. + self.register_buffer("non_scalar_row_index", node_row_index, persistent=True) + self.register_buffer( + "zonal_m0_col_index_for_row", + node_zonal_m0_col_index, + persistent=True, + ) + self.register_buffer( + "radial_slot_index_for_row", + node_radial_l_index, + persistent=True, + ) def forward( self, @@ -225,6 +200,7 @@ def forward( n_nodes: int, edge_cache: EdgeFeatureCache, radial_feat: torch.Tensor, + zonal_coupling: torch.Tensor | None = None, ) -> torch.Tensor: """ Parameters @@ -235,6 +211,9 @@ def forward( Per-edge cache containing geometry, weights, and Wigner-D blocks. radial_feat Per-edge radial features with shape (E, lmax, C) for l=1..lmax. + zonal_coupling + Optional precomputed zonal coupling with shape (E, D-1). If None, + it is gathered from ``edge_cache.Dt_full``. Returns ------- @@ -253,12 +232,13 @@ def forward( # === Step 2. Gather all m=0 columns (l >= 1) in one shot === # Advanced indexing pairs one packed non-scalar row with the zonal m=0 column # from the same degree block in Dt_full. - Dt_full = edge_cache.Dt_full # (E, D, D) - zonal_m0_value_for_row = Dt_full[ - :, - self.non_scalar_row_index, - self.zonal_m0_col_index_for_row, - ] # (E, D-1) + if zonal_coupling is None: + Dt_full = edge_cache.Dt_full # (E, D, D) + zonal_coupling = Dt_full[ + :, + self.non_scalar_row_index, + self.zonal_m0_col_index_for_row, + ] # (E, D-1) # === Step 3. Broadcast radial features per row === # Each non-scalar packed row reuses the radial feature of its degree l. @@ -266,7 +246,7 @@ def forward( 1, self.radial_slot_index_for_row ) # (E, D-1, C) non_scalar_message = ( - zonal_m0_value_for_row.unsqueeze(-1) * radial_value_for_row + zonal_coupling.unsqueeze(-1) * radial_value_for_row ) # (E, D-1, C) # === Step 4. Source Freeze Propagation Gate (optional) === diff --git a/deepmd/pt/model/descriptor/sezm_nn/ffn.py b/deepmd/pt/model/descriptor/sezm_nn/ffn.py index 0e06162bf5..0e9163bf6d 100644 --- a/deepmd/pt/model/descriptor/sezm_nn/ffn.py +++ b/deepmd/pt/model/descriptor/sezm_nn/ffn.py @@ -33,13 +33,15 @@ from .activation import ( GatedActivation, - S2GridProjector, - SwiGLU, - SwiGLUS2Activation, +) +from .grid_net import ( + S2GridNet, + SO3GridNet, +) +from .projection import ( resolve_s2_grid_resolution, ) from .so3 import ( - ChannelLinear, SO3Linear, ) from .utils import ( @@ -59,10 +61,10 @@ class EquivariantFFN(nn.Module): Default structure (glu_activation=True): SO3 linear (in -> 2*hidden) -> split -> GatedActivation(val, gate) -> SO3 linear (hidden -> out) - Optional grid-FFN structure (grid_mlp=True): + Optional grid-FFN structure (s2_activation=True or ffn_so3_grid=True): SO3 linear (in -> hidden) - -> project packed SO(3) coefficients to the S2 grid - -> packed S2-grid point-wise MLP on hidden features + -> project packed SO(3) coefficients to the S2 or SO3 grid + -> grid GLU, polynomial MLP, or scalar-routed attention on hidden features -> project grid features back to packed SO(3) coefficients -> add scalar LinearSwiGLU branch to l=0 -> SO3 linear (hidden -> out) @@ -84,15 +86,21 @@ class EquivariantFFN(nn.Module): Number of channels per (l, m) coefficient. hidden_channels Hidden dimension for the FFN. + kmax + Maximum Wigner-D frame order (|k|) used by the SO3 Wigner-D FFN grid. grid_mlp - If True, use the optional grid-MLP FFN structure on the block-internal - FFN path. This path takes precedence over the simpler activation-only - path inside this module. + If True, select the polynomial grid MLP operation when the + block-internal FFN grid path is enabled. + grid_branch + Number of scalar-routed polynomial product branches used when the + block-internal FFN grid path is enabled. ``0`` disables this branch + mixer. Positive values take precedence over ``grid_mlp``. dtype Parameter dtype. s2_activation - If True and ``grid_mlp=False``, replace the default GatedActivation path - with the merged scalar/grid SwiGLU-S2 activation. + If True, enable the S2 FFN grid path. + ffn_so3_grid + If True, enable the SO3 Wigner-D FFN grid path. lebedev_quadrature If True, use Lebedev quadrature for the S2 projector in this FFN. activation_function @@ -115,9 +123,12 @@ def __init__( lmax: int, channels: int, hidden_channels: int, + kmax: int = 1, grid_mlp: bool = False, + grid_branch: int = 0, dtype: torch.dtype, s2_activation: bool = False, + ffn_so3_grid: bool = False, lebedev_quadrature: bool = False, activation_function: str = "silu", glu_activation: bool = True, @@ -129,8 +140,16 @@ def __init__( self.lmax = int(lmax) self.channels = int(channels) self.hidden_channels = int(hidden_channels) + self.kmax = int(kmax) + if self.kmax < 0: + raise ValueError("`kmax` must be non-negative") self.use_grid_mlp = bool(grid_mlp) + self.grid_branch = int(grid_branch) + if self.grid_branch < 0: + raise ValueError("`grid_branch` must be non-negative") + self.use_grid_branch = self.grid_branch > 0 self.s2_activation = bool(s2_activation) + self.ffn_so3_grid = bool(ffn_so3_grid) self.lebedev_quadrature = bool(lebedev_quadrature) self.s2_grid_method = "lebedev" if self.lebedev_quadrature else "e3nn" base_grid = resolve_s2_grid_resolution( @@ -150,6 +169,7 @@ def __init__( self.compute_dtype = get_promoted_dtype(self.dtype) self.device = env.DEVICE self.precision = RESERVED_PRECISION_DICT[dtype] + self.grid_n_frames = 2 * self.kmax + 1 if self.ffn_so3_grid else 1 # === Step 0. Split deterministic seeds at the module top-level === seed_so3_in = child_seed(seed, 0) @@ -157,10 +177,11 @@ def __init__( seed_so3_out = child_seed(seed, 2) # === First SO3Linear for channel mixing === - # Grid-FFN keeps the hidden width and performs the nonlinear expansion - # inside the scalar/grid point-wise MLPs. + self.use_grid_net = self.s2_activation or self.ffn_so3_grid linear1_out_channels = self.hidden_channels - if not self.use_grid_mlp: + if self.use_grid_net: + linear1_out_channels = 2 * self.grid_n_frames * self.hidden_channels + else: linear1_out_channels = ( 2 * self.hidden_channels if self.glu_activation @@ -178,50 +199,44 @@ def __init__( ) # === Equivariant nonlinearity path === - self.scalar_mlp: nn.Module | None = None - self.grid_projector: S2GridProjector | None = None - self.pointwise_grid_mlp: nn.Module | None = None - if self.use_grid_mlp: - self.scalar_mlp = nn.Sequential( - ChannelLinear( - in_channels=self.channels, - out_channels=2 * self.hidden_channels, - dtype=dtype, - bias=self.mlp_bias, - trainable=trainable, - seed=child_seed(seed_act, 0), - ), - SwiGLU(), - ) - self.grid_projector = S2GridProjector( - lmax=self.lmax, - mmax=self.lmax, - dtype=dtype, - grid_resolution_list=self.s2_grid_resolution, - coefficient_layout="packed", - grid_method=self.s2_grid_method, - ) - self.pointwise_grid_mlp = PointwiseGridMLP( - channels=self.hidden_channels, - dtype=dtype, - trainable=trainable, - seed=child_seed(seed_act, 1), - ) - self.act = nn.Identity() - elif self.s2_activation: - self.act = SwiGLUS2Activation( - lmax=self.lmax, - channels=self.hidden_channels, - dtype=self.compute_dtype, - n_focus=1, - layout="ndfc", - grid_resolution_list=self.s2_grid_resolution, - coefficient_layout="packed", - grid_method=self.s2_grid_method, - mlp_bias=self.mlp_bias, - trainable=trainable, - seed=seed_act, + if self.use_grid_net: + grid_op = ( + "branch" + if self.use_grid_branch + else ("mlp" if self.use_grid_mlp else "glu") ) + if self.ffn_so3_grid: + self.act = SO3GridNet( + lmax=self.lmax, + kmax=self.kmax, + channels=self.hidden_channels, + n_focus=1, + mode="self", + op_type=grid_op, + dtype=self.compute_dtype, + layout="ndfc", + grid_branches=max(1, self.grid_branch), + mlp_bias=self.mlp_bias, + trainable=trainable, + seed=seed_act, + ) + else: + self.act = S2GridNet( + lmax=self.lmax, + channels=self.hidden_channels, + n_focus=1, + mode="self", + op_type=grid_op, + dtype=self.compute_dtype, + layout="ndfc", + grid_resolution_list=self.s2_grid_resolution, + coefficient_layout="packed", + grid_method=self.s2_grid_method, + grid_branches=max(1, self.grid_branch), + mlp_bias=self.mlp_bias, + trainable=trainable, + seed=seed_act, + ) else: self.act = GatedActivation( lmax=self.lmax, @@ -238,7 +253,7 @@ def __init__( # Zero-initialized so residual path starts near-identity. self.so3_linear_2 = SO3Linear( lmax=self.lmax, - in_channels=self.hidden_channels, + in_channels=self.grid_n_frames * self.hidden_channels, out_channels=self.channels, n_focus=1, dtype=dtype, @@ -264,20 +279,10 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: Output with shape (N, D, F, C). """ # === Step 1. Input up projection === - x_input = x x = self.so3_linear_1(x) # === Step 2. Equivariant nonlinearity === - if self.use_grid_mlp: - scalar_outputs = self.scalar_mlp(x_input.select(dim=1, index=0)) - x_flat, shape_info = self._flatten_grid_inputs(x) - x_grid = self.grid_projector.to_grid(x_flat.to(dtype=self.dtype)) - x_grid = self.pointwise_grid_mlp(x_grid) - x = self._restore_grid_outputs( - self.grid_projector.from_grid(x_grid), shape_info - ) - x[:, 0, :, :].add_(scalar_outputs) - elif self.s2_activation: + if self.use_grid_net: x = self.act(x) elif self.glu_activation: # Split into value and gate branches along channel dimension @@ -292,23 +297,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return x - def _flatten_grid_inputs( - self, x: torch.Tensor - ) -> tuple[torch.Tensor, tuple[int, int, int]]: - n_batch, coeff_dim, n_focus, _ = x.shape - return ( - x.permute(0, 2, 1, 3).reshape(n_batch * n_focus, coeff_dim, x.shape[-1]), - (n_batch, coeff_dim, n_focus), - ) - - def _restore_grid_outputs( - self, x: torch.Tensor, shape_info: tuple[int, int, int] - ) -> torch.Tensor: - n_batch, coeff_dim, n_focus = shape_info - return x.reshape(n_batch, n_focus, coeff_dim, self.hidden_channels).permute( - 0, 2, 1, 3 - ) - def serialize(self) -> dict[str, Any]: trainable = all(p.requires_grad for p in self.parameters()) state = self.state_dict() @@ -319,9 +307,12 @@ def serialize(self) -> dict[str, Any]: "lmax": self.lmax, "channels": self.channels, "hidden_channels": self.hidden_channels, + "kmax": self.kmax, "grid_mlp": self.use_grid_mlp, + "grid_branch": self.grid_branch, "precision": RESERVED_PRECISION_DICT[self.dtype], "s2_activation": self.s2_activation, + "ffn_so3_grid": self.ffn_so3_grid, "lebedev_quadrature": self.lebedev_quadrature, "activation_function": self.activation_function, "glu_activation": self.glu_activation, @@ -354,53 +345,3 @@ def deserialize(cls, data: dict[str, Any]) -> EquivariantFFN: } obj.load_state_dict(state) return obj - - -class PointwiseGridMLP(nn.Module): - """ - Apply a two-layer point-wise MLP on flattened S2 grid features. - - Parameters - ---------- - channels - Hidden feature dimension on the grid. - dtype - Parameter dtype. - trainable - Whether parameters are trainable. - seed - Random seed for weight initialization. - """ - - def __init__( - self, - *, - channels: int, - dtype: torch.dtype, - trainable: bool, - seed: int | list[int] | None = None, - ) -> None: - super().__init__() - self.channels = int(channels) - self.linear_1 = ChannelLinear( - in_channels=self.channels, - out_channels=2 * self.channels, - dtype=dtype, - bias=False, - trainable=trainable, - seed=child_seed(seed, 0), - ) - self.act = SwiGLU() - self.linear_2 = ChannelLinear( - in_channels=self.channels, - out_channels=self.channels, - dtype=dtype, - bias=False, - trainable=trainable, - seed=child_seed(seed, 1), - ) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - """Apply the point-wise grid MLP.""" - x = self.act(self.linear_1(x)) - return self.linear_2(x) diff --git a/deepmd/pt/model/descriptor/sezm_nn/grid_net.py b/deepmd/pt/model/descriptor/sezm_nn/grid_net.py new file mode 100644 index 0000000000..89d79d13d5 --- /dev/null +++ b/deepmd/pt/model/descriptor/sezm_nn/grid_net.py @@ -0,0 +1,757 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +""" +Grid-space nonlinearities for SeZM coefficient tensors. + +A grid net receives coefficient tensors, converts them to quadrature values, +applies one point-wise grid operation, and projects the result back to +coefficients. The public shapes are: + +* ``mode='self'``: one input ``(N, D, F, 2*C)`` or ``(N, F, D, 2*C)``. +* ``mode='cross'``: query and context inputs with separate ``C`` channels. +* grid values: ``(N, G, F, C)`` after S2 or SO3 projection. + +The only nonlinear scalar functions are SwiGLU, sigmoid, and softmax on the +``l=0`` scalar branch. Non-scalar grid values use channel-linear maps and +point-wise products so equivariance is governed by the projector quadrature. +""" + +from __future__ import ( + annotations, +) + +from typing import ( + Literal, +) + +import torch +import torch.nn as nn + +from deepmd.dpmodel.utils.seed import ( + child_seed, +) +from deepmd.pt.utils import ( + env, +) +from deepmd.pt.utils.utils import ( + get_generator, +) + +from .activation import ( + SwiGLU, +) +from .indexing import ( + build_m_major_l_index, + map_degree_idx, +) +from .projection import ( + BaseGridProjector, + S2GridProjector, + SO3GridProjector, +) +from .so3 import ( + ChannelLinear, + FocusLinear, +) + +GridNetLayout = Literal["ndfc", "nfdc", "flat"] +GridNetMode = Literal["self", "cross"] +GridNetOp = Literal["glu", "mlp", "branch"] + + +def _build_frame_degree_index( + *, + lmax: int, + mmax: int, + coefficient_layout: str, +) -> torch.Tensor: + """Build the per-coefficient degree index used by frame channel mixers.""" + coefficient_layout = str(coefficient_layout).lower() + if coefficient_layout == "m_major": + return build_m_major_l_index(lmax, mmax, device=env.DEVICE) + if coefficient_layout == "packed": + return map_degree_idx(lmax, device=env.DEVICE) + raise ValueError("`coefficient_layout` must be either 'packed' or 'm_major'") + + +class GridMLP(nn.Module): + """Polynomial point-wise MLP applied independently at every grid point.""" + + def __init__( + self, + *, + channels: int, + mode: GridNetMode, + dtype: torch.dtype, + trainable: bool, + seed: int | list[int] | None = None, + ) -> None: + super().__init__() + self.channels = int(channels) + self.mode = str(mode).lower() + if self.mode not in {"self", "cross"}: + raise ValueError("`mode` must be either 'self' or 'cross'") + self.input_channels = ( + 2 * self.channels if self.mode == "self" else self.channels + ) + self.hidden_channels = 2 * self.channels + self.left_proj = ChannelLinear( + in_channels=self.input_channels, + out_channels=self.hidden_channels, + dtype=dtype, + bias=False, + trainable=trainable, + seed=child_seed(seed, 0), + ) + self.right_proj = ChannelLinear( + in_channels=self.input_channels, + out_channels=self.hidden_channels, + dtype=dtype, + bias=False, + trainable=trainable, + seed=child_seed(seed, 1), + ) + self.out_proj = ChannelLinear( + in_channels=self.hidden_channels, + out_channels=self.channels, + dtype=dtype, + bias=False, + trainable=trainable, + seed=child_seed(seed, 2), + ) + + def forward( + self, query_grid: torch.Tensor, context_grid: torch.Tensor + ) -> torch.Tensor: + """ + Apply the point-wise polynomial MLP to ``(N, G, F, C)`` grid fields. + + In self mode, both projections see ``concat(query_grid, context_grid)`` + and can form self and cross quadratic channel terms. In cross mode, + the query and context roles stay separate: + ``(W_q query_grid) * (W_c context_grid)``. + """ + if self.mode == "self": + grid = torch.cat([query_grid, context_grid], dim=-1) + left = self.left_proj(grid) + right = self.right_proj(grid) + else: + left = self.left_proj(query_grid) + right = self.right_proj(context_grid) + return self.out_proj(left * right) + + +class GridBranch(nn.Module): + """ + Scalar-routed polynomial mixer over grid product branches. + + The softmax sees only invariant scalar inputs. Each branch is a + quadratic product of grid fields, so rotations only act through the grid + argument and the operation remains as band-limited as the product path. + """ + + def __init__( + self, + *, + channels: int, + n_branches: int, + dtype: torch.dtype, + trainable: bool, + seed: int | list[int] | None = None, + ) -> None: + super().__init__() + self.channels = int(channels) + self.n_branches = int(n_branches) + if self.n_branches < 1: + raise ValueError("`n_branches` must be positive") + self.left_proj = ChannelLinear( + in_channels=self.channels, + out_channels=self.n_branches * self.channels, + dtype=dtype, + bias=False, + trainable=trainable, + seed=child_seed(seed, 0), + ) + self.right_proj = ChannelLinear( + in_channels=self.channels, + out_channels=self.n_branches * self.channels, + dtype=dtype, + bias=False, + trainable=trainable, + seed=child_seed(seed, 1), + ) + self.router = ChannelLinear( + in_channels=2 * self.channels, + out_channels=self.n_branches, + dtype=dtype, + bias=False, + trainable=trainable, + seed=child_seed(seed, 2), + ) + self.out_proj = ChannelLinear( + in_channels=self.channels, + out_channels=self.channels, + dtype=dtype, + bias=False, + trainable=trainable, + seed=child_seed(seed, 3), + ) + + def forward( + self, + query_grid: torch.Tensor, + context_grid: torch.Tensor, + scalar_pair: torch.Tensor, + ) -> torch.Tensor: + """ + Apply scalar-routed grid branch mixing. + + Parameters + ---------- + query_grid + First grid source with shape ``(N, G, F, C)``. + context_grid + Second grid source with shape ``(N, G, F, C)``. + scalar_pair + Invariant router source with shape ``(N, F, 2*C)``. + """ + n_batch, n_grid, n_focus, _ = query_grid.shape + left = self.left_proj(query_grid) + right = self.right_proj(context_grid) + value = (left * right).reshape( + n_batch, + n_grid, + n_focus, + self.n_branches, + self.channels, + ) # (N, G, F, N_branches, C) + router = torch.softmax(self.router(scalar_pair), dim=-1) # (N, F, N_branches) + out = torch.einsum("ngfhc,nfh->ngfc", value, router) # (N, G, F, C) + return self.out_proj(out) + + +class FrameContract(nn.Module): + """Per-degree frame/channel contraction that preserves the order index.""" + + def __init__( + self, + *, + lmax: int, + mmax: int, + coefficient_layout: str, + n_frames: int, + channels: int, + dtype: torch.dtype, + trainable: bool, + seed: int | list[int] | None = None, + ) -> None: + super().__init__() + self.lmax = int(lmax) + self.mmax = int(mmax) + self.coefficient_layout = str(coefficient_layout).lower() + self.n_frames = int(n_frames) + self.channels = int(channels) + degree_index = _build_frame_degree_index( + lmax=self.lmax, + mmax=self.mmax, + coefficient_layout=self.coefficient_layout, + ) + self.register_buffer("degree_index", degree_index, persistent=False) + self.weight = nn.Parameter( + torch.empty( + self.lmax + 1, + self.n_frames * self.channels, + self.channels, + dtype=dtype, + device=env.DEVICE, + ) + ) + bound = 1.0 / (self.n_frames * self.channels) ** 0.5 + nn.init.uniform_(self.weight, -bound, bound, generator=get_generator(seed)) + for param in self.parameters(): + param.requires_grad = trainable + + def forward(self, coeff: torch.Tensor) -> torch.Tensor: + """Contract ``(N, D, F, K*C)`` frame coefficients to ``(N, D, F, C)``.""" + weight = self.weight.index_select(0, self.degree_index) + return torch.einsum("ndfi,dio->ndfo", coeff, weight) + + +class FrameExpand(nn.Module): + """Per-degree frame/channel expansion that preserves the order index.""" + + def __init__( + self, + *, + lmax: int, + mmax: int, + coefficient_layout: str, + n_frames: int, + channels: int, + dtype: torch.dtype, + trainable: bool, + seed: int | list[int] | None = None, + ) -> None: + super().__init__() + self.lmax = int(lmax) + self.mmax = int(mmax) + self.coefficient_layout = str(coefficient_layout).lower() + self.n_frames = int(n_frames) + self.channels = int(channels) + degree_index = _build_frame_degree_index( + lmax=self.lmax, + mmax=self.mmax, + coefficient_layout=self.coefficient_layout, + ) + self.register_buffer("degree_index", degree_index, persistent=False) + self.weight = nn.Parameter( + torch.empty( + self.lmax + 1, + self.channels, + self.n_frames * self.channels, + dtype=dtype, + device=env.DEVICE, + ) + ) + bound = 1.0 / self.channels**0.5 + nn.init.uniform_(self.weight, -bound, bound, generator=get_generator(seed)) + for param in self.parameters(): + param.requires_grad = trainable + + def forward(self, coeff: torch.Tensor) -> torch.Tensor: + """Expand ``(N, D, F, C)`` coefficients to ``(N, D, F, K*C)``.""" + weight = self.weight.index_select(0, self.degree_index) + return torch.einsum("ndfi,dio->ndfo", coeff, weight) + + +class BaseGridNet(nn.Module): + """ + Shared implementation for S2 and SO(3) grid nets. + + ``mode='self'`` expects one input whose last channel axis contains two + branches. ``mode='cross'`` expects query and context inputs; the query side + is the source of attention queries and SwiGLU gates, while the context side + is the key/value or second product branch. + """ + + def __init__( + self, + *, + projector: BaseGridProjector, + channels: int, + n_focus: int, + mode: GridNetMode, + op_type: GridNetOp, + dtype: torch.dtype, + layout: GridNetLayout, + mlp_bias: bool, + trainable: bool, + grid_branches: int = 1, + frame_expand: nn.Module | None = None, + frame_contract: nn.Module | None = None, + residual_scale_init: float | None = None, + seed: int | list[int] | None = None, + ) -> None: + super().__init__() + self.projector = projector.to(device=env.DEVICE) + self.lmax = int(projector.lmax) + self.channels = int(channels) + self.n_focus = int(n_focus) + self.n_frames = int(projector.n_frames) + self.mode = str(mode).lower() + if self.mode not in {"self", "cross"}: + raise ValueError("`mode` must be either 'self' or 'cross'") + self.op_type = str(op_type).lower() + if self.op_type not in {"glu", "mlp", "branch"}: + raise ValueError("`op_type` must be one of 'glu', 'mlp', or 'branch'") + self.dtype = dtype + self.layout = str(layout).lower() + if self.layout not in {"ndfc", "nfdc", "flat"}: + raise ValueError("`layout` must be one of 'ndfc', 'nfdc', or 'flat'") + if self.mode == "self" and self.layout == "flat": + raise ValueError("`layout='flat'` is only supported for cross grid nets") + self.mlp_bias = bool(mlp_bias) + self.expanded_channels = self.n_frames * self.channels + self.frame_expand = frame_expand + self.frame_contract = frame_contract + self.query_channels = ( + 2 * self.expanded_channels + if self.mode == "self" + else ( + self.channels + if self.frame_expand is not None + else self.expanded_channels + ) + ) + self.context_channels = ( + self.channels if self.frame_expand is not None else self.expanded_channels + ) + self.output_channels = ( + self.channels if self.frame_contract is not None else self.expanded_channels + ) + self.frame_zero_index = int(getattr(projector, "frame_zero_index", 0)) + + self.scalar_act = SwiGLU() + self.scalar_gate = FocusLinear( + in_channels=2 * self.channels, + out_channels=self.channels, + n_focus=self.n_focus, + dtype=self.dtype, + bias=self.mlp_bias, + trainable=trainable, + seed=child_seed(seed, 0), + init_std=0.01, + ) + if self.op_type == "mlp": + self.grid_op = GridMLP( + channels=self.channels, + mode=self.mode, + dtype=self.dtype, + trainable=trainable, + seed=child_seed(seed, 1), + ) + elif self.op_type == "branch": + self.grid_op = GridBranch( + channels=self.channels, + n_branches=grid_branches, + dtype=self.dtype, + trainable=trainable, + seed=child_seed(seed, 1), + ) + else: + self.grid_op = nn.Identity() + + if residual_scale_init is None: + self.residual_scale = None + else: + self.residual_scale = nn.Parameter( + torch.ones( + self.n_focus, + self.output_channels, + dtype=self.dtype, + device=env.DEVICE, + ) + * float(residual_scale_init), + requires_grad=trainable, + ) + + def forward( + self, + query: torch.Tensor, + context: torch.Tensor | None = None, + ) -> torch.Tensor: + """Apply the configured grid net and restore the input layout.""" + input_dtype = query.dtype + query_ndfc, shape_info = self._to_ndfc(query) + left, right, scalar_pair = self._prepare_pair(query_ndfc, context) + grid_out = self._apply_grid_op(left, right, scalar_pair) + coeff_out = self._from_grid(grid_out) + coeff_out = self._apply_scalar_path(coeff_out, scalar_pair) + coeff_out = self._contract_frames(coeff_out) + coeff_out = self._apply_residual_scale(coeff_out) + return self._restore_layout(coeff_out.to(dtype=input_dtype), shape_info) + + def _prepare_pair( + self, + query: torch.Tensor, + context: torch.Tensor | None, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + if self.mode == "self": + return self._prepare_self_pair(query) + return self._prepare_cross_pair(query, context) + + def _prepare_self_pair( + self, + query: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + left, right = self._split_self_query(query) + scalar_pair = self._make_scalar_pair(left, right) + return left, right, scalar_pair + + def _prepare_cross_pair( + self, + query: torch.Tensor, + context: torch.Tensor | None, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + if context is None: + raise ValueError("`context` is required when `mode='cross'`") + context_ndfc, _ = self._to_ndfc(context) + self._check_last_dim(query, self.context_channels, "query") + self._check_last_dim(context_ndfc, self.context_channels, "context") + if self.frame_expand is None: + scalar_pair = self._make_scalar_pair(query, context_ndfc) + return query, context_ndfc, scalar_pair + + scalar_pair = torch.cat( + [ + query[:, 0, :, :], + context_ndfc[:, 0, :, :], + ], + dim=-1, + ).to(dtype=self.dtype) + return ( + self.frame_expand(query), + self.frame_expand(context_ndfc), + scalar_pair, + ) + + def _apply_grid_op( + self, + left: torch.Tensor, + right: torch.Tensor, + scalar_pair: torch.Tensor, + ) -> torch.Tensor: + left_grid = self._to_grid(left.to(dtype=self.dtype)) + right_grid = self._to_grid(right.to(dtype=self.dtype)) + if self.op_type == "glu": + return left_grid * right_grid + if self.op_type == "mlp": + return self.grid_op(left_grid, right_grid) + return self.grid_op(left_grid, right_grid, scalar_pair) + + def _contract_frames(self, coeff: torch.Tensor) -> torch.Tensor: + if self.frame_contract is None: + return coeff + return self.frame_contract(coeff) + + def _apply_residual_scale(self, coeff: torch.Tensor) -> torch.Tensor: + if self.residual_scale is None: + return coeff + return coeff * self.residual_scale.reshape( + 1, + 1, + self.n_focus, + self.output_channels, + ) + + def _apply_scalar_path( + self, + coeff: torch.Tensor, + scalar_pair: torch.Tensor, + ) -> torch.Tensor: + scalar_out = self.scalar_act(scalar_pair) + scalar_gate = torch.sigmoid(self.scalar_gate(scalar_pair)) + n_batch, coeff_dim, n_focus, _ = coeff.shape + coeff_view = coeff.reshape( + n_batch, + coeff_dim, + n_focus, + self.n_frames, + self.channels, + ) + coeff_view = coeff_view * scalar_gate[:, None, :, None, :] + coeff_view[:, 0, :, self.frame_zero_index, :].add_(scalar_out) + return coeff_view.reshape(n_batch, coeff_dim, n_focus, self.expanded_channels) + + def _split_self_query( + self, query: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: + self._check_last_dim(query, self.query_channels, "query") + return torch.chunk(query, chunks=2, dim=-1) + + def _make_scalar_pair( + self, left: torch.Tensor, right: torch.Tensor + ) -> torch.Tensor: + return torch.cat( + [ + self._extract_scalar(left), + self._extract_scalar(right), + ], + dim=-1, + ).to(dtype=self.dtype) + + def _extract_scalar(self, coeff: torch.Tensor) -> torch.Tensor: + n_batch, _, n_focus, _ = coeff.shape + coeff_view = coeff.reshape( + n_batch, + coeff.shape[1], + n_focus, + self.n_frames, + self.channels, + ) + return coeff_view[:, 0, :, self.frame_zero_index, :] + + def _to_grid(self, coeff: torch.Tensor) -> torch.Tensor: + n_batch, coeff_dim, n_focus, _ = coeff.shape + coeff_view = coeff.reshape( + n_batch, + coeff_dim, + n_focus, + self.n_frames, + self.channels, + ) + to_grid = self.projector.to_grid_mat.reshape( + self.projector.grid_size, + coeff_dim, + self.n_frames, + ) + return torch.einsum("gdk,ndfkc->ngfc", to_grid, coeff_view) + + def _from_grid(self, grid: torch.Tensor) -> torch.Tensor: + n_batch, _, n_focus, _ = grid.shape + coeff_dim = self.projector.coeff_dim // self.n_frames + from_grid = self.projector.from_grid_mat.reshape( + coeff_dim, + self.n_frames, + self.projector.grid_size, + ) + coeff = torch.einsum("dkg,ngfc->ndfkc", from_grid, grid) + return coeff.reshape(n_batch, coeff_dim, n_focus, self.expanded_channels) + + def _to_ndfc(self, value: torch.Tensor) -> tuple[torch.Tensor, tuple[int, ...]]: + if self.layout == "ndfc": + return value, tuple(value.shape) + if self.layout == "nfdc": + return value.transpose(1, 2), tuple(value.shape) + n_batch, coeff_dim, _ = value.shape + return ( + value.reshape(n_batch, coeff_dim, self.n_focus, -1), + tuple(value.shape), + ) + + def _restore_layout( + self, + value: torch.Tensor, + shape_info: tuple[int, ...], + ) -> torch.Tensor: + if self.layout == "ndfc": + return value + if self.layout == "nfdc": + return value.transpose(1, 2) + n_batch, coeff_dim, _ = shape_info + return value.reshape(n_batch, coeff_dim, -1) + + def _check_last_dim( + self, + value: torch.Tensor, + expected: int, + name: str, + ) -> None: + if value.shape[-1] != expected: + raise ValueError( + f"`{name}` last dimension must be {expected}, got {value.shape[-1]}" + ) + + +class S2GridNet(BaseGridNet): + """Grid net using an S2 spherical-harmonic projector.""" + + def __init__( + self, + *, + lmax: int, + mmax: int | None = None, + channels: int, + n_focus: int = 1, + mode: GridNetMode, + op_type: GridNetOp, + dtype: torch.dtype, + layout: GridNetLayout, + grid_resolution_list: list[int] | None = None, + coefficient_layout: str = "packed", + grid_method: str = "e3nn", + grid_branches: int = 1, + residual_scale_init: float | None = None, + mlp_bias: bool = False, + trainable: bool, + seed: int | list[int] | None = None, + ) -> None: + projector = S2GridProjector( + lmax=lmax, + mmax=mmax, + dtype=dtype, + grid_resolution_list=grid_resolution_list, + coefficient_layout=coefficient_layout, + grid_method=grid_method, + ) + self.grid_resolution_list = projector.grid_resolution_list + self.grid_method = projector.grid_method + super().__init__( + projector=projector, + channels=channels, + n_focus=n_focus, + mode=mode, + op_type=op_type, + dtype=dtype, + layout=layout, + mlp_bias=mlp_bias, + trainable=trainable, + grid_branches=grid_branches, + residual_scale_init=residual_scale_init, + seed=seed, + ) + + +class SO3GridNet(BaseGridNet): + """Grid net using a Wigner-D SO(3) projector with frame indices.""" + + def __init__( + self, + *, + lmax: int, + mmax: int | None = None, + kmax: int = 1, + channels: int, + n_focus: int = 1, + mode: GridNetMode, + op_type: GridNetOp, + dtype: torch.dtype, + layout: GridNetLayout, + lebedev_precision: int | None = None, + coefficient_layout: str = "packed", + grid_branches: int = 1, + residual_scale_init: float | None = None, + mlp_bias: bool = False, + trainable: bool, + seed: int | list[int] | None = None, + ) -> None: + projector = SO3GridProjector( + lmax=lmax, + mmax=mmax, + kmax=kmax, + dtype=dtype, + lebedev_precision=lebedev_precision, + coefficient_layout=coefficient_layout, + ) + self.frames = projector.frame_set + self.kmax = projector.kmax + self.lebedev_precision = projector.lebedev_precision + self.n_gamma = projector.n_gamma + frame_expand = None + frame_contract = None + if mode == "cross": + frame_expand = FrameExpand( + lmax=lmax, + mmax=projector.mmax, + coefficient_layout=coefficient_layout, + n_frames=projector.n_frames, + channels=channels, + dtype=dtype, + trainable=trainable, + seed=child_seed(seed, 4), + ) + frame_contract = FrameContract( + lmax=lmax, + mmax=projector.mmax, + coefficient_layout=coefficient_layout, + n_frames=projector.n_frames, + channels=channels, + dtype=dtype, + trainable=trainable, + seed=child_seed(seed, 5), + ) + super().__init__( + projector=projector, + channels=channels, + n_focus=n_focus, + mode=mode, + op_type=op_type, + dtype=dtype, + layout=layout, + mlp_bias=mlp_bias, + trainable=trainable, + grid_branches=grid_branches, + frame_expand=frame_expand, + frame_contract=frame_contract, + residual_scale_init=residual_scale_init, + seed=seed, + ) diff --git a/deepmd/pt/model/descriptor/sezm_nn/indexing.py b/deepmd/pt/model/descriptor/sezm_nn/indexing.py index e550c9053b..4ff2077703 100644 --- a/deepmd/pt/model/descriptor/sezm_nn/indexing.py +++ b/deepmd/pt/model/descriptor/sezm_nn/indexing.py @@ -76,6 +76,44 @@ def map_degree_idx(lmax: int, *, device: torch.device) -> torch.Tensor: ) +def build_gie_zonal_index( + lmax: int, *, device: torch.device +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Build node-level packed indices for GIE zonal coupling. + + The returned tensors are aligned row-wise for every non-scalar packed + coefficient in the node representation. They select the local ``m=0`` column + of the matching degree from ``Dt_full`` or an equivalent zonal coupling table. + + Parameters + ---------- + lmax + Maximum node degree used by the geometric initial embedding. + device + Device for the returned tensors. + + Returns + ------- + tuple[torch.Tensor, torch.Tensor, torch.Tensor] + ``(node_row_index, node_zonal_m0_col_index, node_radial_l_index)``. + The first two index packed SO(3) rows/columns; the last one indexes + radial features with degree slots ``l=1..lmax`` stored as ``0..lmax-1``. + """ + lmax_i = int(lmax) + ebed_dim = get_so3_dim_of_lmax(lmax_i) + if lmax_i == 0: + empty = torch.empty(0, device=device, dtype=torch.long) + return empty, empty, empty + + packed_degree_by_row = map_degree_idx(lmax_i, device=device) + node_row_index = torch.arange(1, ebed_dim, device=device, dtype=torch.long) + node_degree_by_row = packed_degree_by_row[1:] + node_zonal_m0_col_index = node_degree_by_row * (node_degree_by_row + 1) + node_radial_l_index = node_degree_by_row - 1 + return node_row_index, node_zonal_m0_col_index, node_radial_l_index + + def project_D_to_m( D_full: torch.Tensor, coeff_index_m: torch.Tensor, diff --git a/deepmd/pt/model/descriptor/sezm_nn/projection.py b/deepmd/pt/model/descriptor/sezm_nn/projection.py new file mode 100644 index 0000000000..e4ec485b19 --- /dev/null +++ b/deepmd/pt/model/descriptor/sezm_nn/projection.py @@ -0,0 +1,594 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +""" +Grid projection helpers for SeZM function-space nonlinearities. + +The projectors in this module only handle basis transforms. They do not apply +channel mixing or nonlinearities. A projector maps coefficient tensors to a +fixed quadrature grid, and maps grid fields back to coefficients with the +matching quadrature rule. +""" + +from __future__ import ( + annotations, +) + +import math +from typing import ( + Any, +) + +import torch +import torch.nn as nn +from e3nn.o3 import ( + FromS2Grid, + ToS2Grid, + spherical_harmonics, +) + +from deepmd.pt.utils.env import ( + PRECISION_DICT, + RESERVED_PRECISION_DICT, +) +from deepmd.utils.version import ( + check_version_compatibility, +) + +from .indexing import ( + build_l_major_index, + build_m_major_index, + so3_packed_index, +) +from .lebedev import ( + LEBEDEV_PRECISION_TO_NPOINTS, + load_lebedev_rule, +) +from .wignerd import ( + WignerDCalculator, + build_edge_quaternion, + quaternion_multiply, + quaternion_z_rotation, +) + + +class BaseGridProjector(nn.Module): + """ + Base class for fixed coefficient-to-grid projection matrices. + + Subclasses build ``to_grid_mat`` with shape ``(G, J)`` and + ``from_grid_mat`` with shape ``(J, G)``, where ``G`` is the number of grid + samples and ``J`` is the flattened coefficient axis consumed by the grid + net. For ordinary S2 projections, ``J`` is the SO(3) feature coefficient + axis: ``D = (lmax + 1)^2`` in packed layout, or the retained ``D_m`` axis in + m-major layout. For SO(3) frame projections, ``J = D * n_frames`` with + frame index packed inside each coefficient row. + """ + + def __init__( + self, + *, + lmax: int, + mmax: int | None, + dtype: torch.dtype, + n_frames: int, + coefficient_layout: str, + ) -> None: + super().__init__() + self.lmax = int(lmax) + self.mmax = int(self.lmax if mmax is None else mmax) + if self.mmax < 0: + raise ValueError("`mmax` must be non-negative") + if self.mmax > self.lmax: + raise ValueError("`mmax` must be <= `lmax`") + self.coefficient_layout = str(coefficient_layout).lower() + if self.coefficient_layout not in {"packed", "m_major"}: + raise ValueError( + "`coefficient_layout` must be either 'packed' or 'm_major'" + ) + self.dtype = dtype + self.device = torch.device("cpu") + self.precision = RESERVED_PRECISION_DICT[dtype] + self.n_frames = int(n_frames) + self.packed_dim = int((self.lmax + 1) ** 2) + + coeff_index = self._build_coefficient_index(device=torch.device("cpu")) + to_grid_mat, from_grid_mat = self._build_projection_mats(coeff_index) + self.coeff_dim = int(to_grid_mat.shape[1]) + self.grid_size = int(to_grid_mat.shape[0]) + if self.coeff_dim != int(from_grid_mat.shape[0]): + raise ValueError("Projection matrix coefficient axes `J` do not match") + if self.grid_size != int(from_grid_mat.shape[1]): + raise ValueError("Projection matrix grid axes `G` do not match") + self.register_buffer( + "to_grid_mat", + to_grid_mat.to(device=self.device, dtype=self.dtype), + persistent=False, + ) + self.register_buffer( + "from_grid_mat", + from_grid_mat.to(device=self.device, dtype=self.dtype), + persistent=False, + ) + + def to_grid(self, embedding: torch.Tensor) -> torch.Tensor: + """Project flattened coefficients ``(N, J, C)`` to grid fields ``(N, G, C)``.""" + to_grid_mat = self.to_grid_mat.to( + device=embedding.device, + dtype=embedding.dtype, + ) + return torch.einsum("gj,njc->ngc", to_grid_mat, embedding) + + def from_grid(self, grid: torch.Tensor) -> torch.Tensor: + """Project grid fields ``(N, G, C)`` back to flattened coefficients ``(N, J, C)``.""" + from_grid_mat = self.from_grid_mat.to( + device=grid.device, + dtype=grid.dtype, + ) + return torch.einsum("jg,ngc->njc", from_grid_mat, grid) + + def _build_coefficient_index(self, device: torch.device) -> torch.Tensor: + """Build the coefficient subset consumed by the projector matrices.""" + if self.coefficient_layout == "m_major": + return build_m_major_index(self.lmax, self.mmax, device=device) + if self.mmax == self.lmax: + return torch.arange((self.lmax + 1) ** 2, device=device, dtype=torch.long) + return build_l_major_index(self.lmax, self.mmax, device=device) + + def _build_projection_mats( + self, + coeff_index: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + """Build ``to_grid_mat (G, J)`` and ``from_grid_mat (J, G)``.""" + raise NotImplementedError + + +class S2GridProjector(BaseGridProjector): + """ + Project SO(3) coefficients to/from a flattened S2 grid. + + Parameters + ---------- + lmax + Maximum spherical harmonic degree. + mmax + Maximum order kept in the coefficient layout. If None, use ``lmax``. + dtype + Buffer dtype used by the projection matrices. + grid_resolution_list + Two-element resolution list. For ``grid_method='e3nn'`` it is + ``[R_phi, R_theta]`` and is converted to the ``e3nn`` + ``(lat, long) = (R_theta, R_phi)`` ordering. For + ``grid_method='lebedev'`` it is ``[precision, n_points]``. + coefficient_layout + Coefficient ordering expected by the caller: + - ``"packed"``: packed ``(l, m)`` order, optionally truncated by ``mmax``. + - ``"m_major"``: reduced m-major order used inside ``SO2Convolution``. + grid_method + S2 quadrature backend. Must be ``"e3nn"`` or ``"lebedev"``. + """ + + def __init__( + self, + *, + lmax: int, + mmax: int | None = None, + dtype: torch.dtype, + grid_resolution_list: list[int] | None = None, + coefficient_layout: str = "packed", + grid_method: str = "e3nn", + ) -> None: + lmax_i = int(lmax) + mmax_i = int(lmax_i if mmax is None else mmax) + self.grid_method = str(grid_method).lower() + if self.grid_method not in {"e3nn", "lebedev"}: + raise ValueError("`grid_method` must be either 'e3nn' or 'lebedev'") + + self.grid_resolution_list = _normalize_s2_grid_resolution( + lmax_i, + mmax_i, + grid_resolution_list, + method=self.grid_method, + ) + if self.grid_method == "e3nn": + self.phi_resolution, self.theta_resolution = self.grid_resolution_list + self.lebedev_precision = 0 + self.lebedev_npoints = 0 + else: + self.phi_resolution = 0 + self.theta_resolution = 0 + self.lebedev_precision, self.lebedev_npoints = self.grid_resolution_list + + super().__init__( + lmax=lmax_i, + mmax=mmax_i, + dtype=dtype, + n_frames=1, + coefficient_layout=coefficient_layout, + ) + + def _rescale_truncated_orders(self, mat: torch.Tensor) -> None: + if self.lmax == self.mmax: + return + for degree in range(self.lmax + 1): + if degree <= self.mmax: + continue + start_idx = degree * degree + length = 2 * degree + 1 + rescale = math.sqrt(length / float(2 * self.mmax + 1)) + mat[:, :, start_idx : start_idx + length].mul_(rescale) + + def _rescale_truncated_matrix(self, mat: torch.Tensor) -> None: + if self.lmax == self.mmax: + return + for degree in range(self.lmax + 1): + if degree <= self.mmax: + continue + start_idx = degree * degree + length = 2 * degree + 1 + rescale = math.sqrt(length / float(2 * self.mmax + 1)) + mat[:, start_idx : start_idx + length].mul_(rescale) + + def _build_projection_mats( + self, + coeff_index: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + if self.grid_method == "lebedev": + return self._build_lebedev_projection_mats(coeff_index) + return self._build_e3nn_projection_mats(coeff_index) + + def _build_e3nn_projection_mats( + self, + coeff_index: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + with torch.device("cpu"): + to_grid = ToS2Grid( + self.lmax, + (self.theta_resolution, self.phi_resolution), + normalization="component", + device="cpu", + ) + to_grid_mat = torch.einsum("mbi,am->bai", to_grid.shb, to_grid.sha).detach() + self._rescale_truncated_orders(to_grid_mat) + + from_grid = FromS2Grid( + (self.theta_resolution, self.phi_resolution), + self.lmax, + normalization="component", + device="cpu", + ) + from_grid_mat = torch.einsum( + "am,mbi->bai", from_grid.sha, from_grid.shb + ).detach() + self._rescale_truncated_orders(from_grid_mat) + + to_grid_mat = to_grid_mat.flatten(0, 1).index_select(1, coeff_index) + from_grid_mat = ( + from_grid_mat.flatten(0, 1).permute(1, 0).index_select(0, coeff_index) + ) + return to_grid_mat, from_grid_mat + + def _build_lebedev_projection_mats( + self, + coeff_index: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + with torch.device("cpu"): + points, weights = load_lebedev_rule( + self.lebedev_precision, + dtype=torch.float64, + device=torch.device("cpu"), + ) + harmonics = spherical_harmonics( + list(range(self.lmax + 1)), + points, + normalize=True, + normalization="norm", + ) + # Match the component-normalized product-grid convention used by + # e3nn's ToS2Grid/FromS2Grid pair so both S2 backends are drop-in + # replacements for the same grid net. + scale = math.sqrt(float(self.lmax + 1)) + degree_factors = harmonics.new_tensor( + [ + float(2 * degree + 1) + for degree in range(self.lmax + 1) + for _ in range(2 * degree + 1) + ] + ) + to_grid_mat = harmonics / scale + from_grid_mat = harmonics * ( + weights[:, None] * scale * degree_factors[None, :] + ) + self._rescale_truncated_matrix(to_grid_mat) + self._rescale_truncated_matrix(from_grid_mat) + + to_grid_mat = to_grid_mat.index_select(1, coeff_index) + from_grid_mat = from_grid_mat.index_select(1, coeff_index).transpose(0, 1) + return to_grid_mat, from_grid_mat + + def serialize(self) -> dict[str, Any]: + return { + "@class": "S2GridProjector", + "@version": 1, + "config": { + "lmax": self.lmax, + "mmax": self.mmax, + "precision": RESERVED_PRECISION_DICT[self.dtype], + "grid_resolution_list": self.grid_resolution_list, + "coefficient_layout": self.coefficient_layout, + "grid_method": self.grid_method, + }, + "@variables": {}, + } + + @classmethod + def deserialize(cls, data: dict[str, Any]) -> S2GridProjector: + data = data.copy() + data_cls = data.pop("@class") + if data_cls != "S2GridProjector": + raise ValueError(f"Invalid class for S2GridProjector: {data_cls}") + version = int(data.pop("@version")) + check_version_compatibility(version, 1, 1) + config = data.pop("config") + data.pop("@variables", None) + precision = config.pop("precision") + config["dtype"] = PRECISION_DICT[precision] + return cls(**config) + + +class SO3GridProjector(BaseGridProjector): + """ + Project SO(3) coefficients to/from a Wigner-D grid with frame indices. + + The coefficient axis is packed as ``(l, m, k)`` with ordinary SeZM + ``(l, m)`` order outside and the configured frame set inside each row. A + frame index outside ``[-l, l]`` is kept as a zero column/row. This keeps the + tensor layout regular while preserving the exact per-degree frame support. + """ + + def __init__( + self, + *, + lmax: int, + mmax: int | None = None, + kmax: int = 1, + dtype: torch.dtype, + lebedev_precision: int | None = None, + coefficient_layout: str = "packed", + ) -> None: + lmax_i = int(lmax) + mmax_i = int(lmax_i if mmax is None else mmax) + self.kmax = int(kmax) + if self.kmax < 0: + raise ValueError("`kmax` must be non-negative") + self.frame_set = _build_so3_frame_set(self.kmax) + self.frame_zero_index = self.frame_set.index(0) + self.lebedev_precision, self.lebedev_npoints, self.n_gamma = resolve_so3_grid( + lmax_i, + kmax=self.kmax, + lebedev_precision=lebedev_precision, + ) + super().__init__( + lmax=lmax_i, + mmax=mmax_i, + dtype=dtype, + n_frames=len(self.frame_set), + coefficient_layout=coefficient_layout, + ) + self.register_buffer( + "frame_values", + torch.tensor(self.frame_set, dtype=torch.long, device=self.device), + persistent=False, + ) + + def _build_projection_mats( + self, + coeff_index: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + with torch.device("cpu"): + points, weights = load_lebedev_rule( + self.lebedev_precision, + dtype=torch.float64, + device=torch.device("cpu"), + ) + gamma = torch.arange( + self.n_gamma, dtype=torch.float64, device=points.device + ) * (2.0 * math.pi / float(self.n_gamma)) + edge_quaternion = build_edge_quaternion(points, eps=1e-14) + edge_quaternion = edge_quaternion.repeat_interleave(self.n_gamma, dim=0) + gamma_quaternion = quaternion_z_rotation(gamma).repeat(points.shape[0], 1) + grid_quaternion = quaternion_multiply(gamma_quaternion, edge_quaternion) + wigner_grid, _ = WignerDCalculator(self.lmax, dtype=torch.float64).to( + torch.device("cpu") + )(grid_quaternion) + # ``build_edge_quaternion`` follows SeZM's global-to-local convention. + # The transpose below stores the local m=0 column in the same layout + # as ``WignerDCalculator.forward_zonal`` and extends it to k != 0. + wigner_grid = wigner_grid.transpose(-1, -2).contiguous() + haar_weight = weights.repeat_interleave(self.n_gamma) / float(self.n_gamma) + + grid_size = int(grid_quaternion.shape[0]) + coeff_dim = int(coeff_index.numel() * len(self.frame_set)) + to_grid_mat = torch.zeros( + grid_size, + coeff_dim, + dtype=torch.float64, + device=points.device, + ) + from_grid_mat = torch.zeros( + coeff_dim, + grid_size, + dtype=torch.float64, + device=points.device, + ) + + for degree in range(self.lmax + 1): + degree_factor = float(2 * degree + 1) + for m_order in range(-degree, degree + 1): + packed_idx = so3_packed_index(degree, m_order) + coeff_positions = (coeff_index == packed_idx).nonzero( + as_tuple=False + ) + if coeff_positions.numel() == 0: + continue + coeff_pos = int(coeff_positions[0, 0]) + for frame_pos, frame_order in enumerate(self.frame_set): + flat_idx = coeff_pos * len(self.frame_set) + frame_pos + if abs(frame_order) > degree: + continue + row = so3_packed_index(degree, m_order) + col = so3_packed_index(degree, frame_order) + values = wigner_grid[:, row, col] + to_grid_mat[:, flat_idx] = values + from_grid_mat[flat_idx, :] = ( + degree_factor * haar_weight * values + ) + return to_grid_mat, from_grid_mat + + def serialize(self) -> dict[str, Any]: + return { + "@class": "SO3GridProjector", + "@version": 1, + "config": { + "lmax": self.lmax, + "mmax": self.mmax, + "kmax": self.kmax, + "precision": RESERVED_PRECISION_DICT[self.dtype], + "lebedev_precision": self.lebedev_precision, + "coefficient_layout": self.coefficient_layout, + }, + "@variables": {}, + } + + @classmethod + def deserialize(cls, data: dict[str, Any]) -> SO3GridProjector: + data = data.copy() + data_cls = data.pop("@class") + if data_cls != "SO3GridProjector": + raise ValueError(f"Invalid class for SO3GridProjector: {data_cls}") + version = int(data.pop("@version")) + check_version_compatibility(version, 1, 1) + config = data.pop("config") + data.pop("@variables", None) + precision = config.pop("precision") + config["dtype"] = PRECISION_DICT[precision] + return cls(**config) + + +def resolve_s2_grid_resolution( + lmax: int, + mmax: int, + *, + method: str = "e3nn", +) -> list[int]: + """ + Resolve the default S2 grid resolution. + + For ``method='e3nn'``, the automatic default uses even azimuthal sampling + ``R_phi = 2 * mmax + 4`` and even polar sampling + ``R_theta = ceil_even(3 * lmax + 2)``. + + For ``method='lebedev'``, the automatic default picks the smallest packaged + Lebedev rule whose algebraic precision is at least ``3 * lmax`` and returns + ``[precision, n_points]``. + """ + method = str(method).lower() + if method not in {"e3nn", "lebedev"}: + raise ValueError("`method` must be either 'e3nn' or 'lebedev'") + if method == "lebedev": + required_precision = 3 * int(lmax) + for precision, n_points in LEBEDEV_PRECISION_TO_NPOINTS.items(): + if precision >= required_precision: + return [precision, n_points] + raise ValueError( + f"No packaged Lebedev rule has precision >= {required_precision}" + ) + + phi_resolution = 2 * int(mmax) + 4 + theta_resolution = 3 * int(lmax) + 2 + theta_resolution += theta_resolution % 2 + return [phi_resolution, theta_resolution] + + +def resolve_so3_grid( + lmax: int, + *, + kmax: int = 1, + lebedev_precision: int | None = None, +) -> tuple[int, int, int]: + """ + Resolve the default SO(3) quadrature as Lebedev sphere times gamma samples. + + The Lebedev precision follows the same conservative ``3*lmax`` rule used by + the S2 grid path. The gamma grid is chosen for the quadratic grid products + used by the SO(3) grid nets, whose third-angle frequency can reach + ``k1 + k2 - kout``. + """ + lmax_i = int(lmax) + kmax_i = int(kmax) + if kmax_i < 0: + raise ValueError("`kmax` must be non-negative") + if lebedev_precision is None: + required_precision = 3 * lmax_i + for precision, n_points in LEBEDEV_PRECISION_TO_NPOINTS.items(): + if precision >= required_precision: + lebedev_precision = precision + lebedev_npoints = n_points + break + else: + raise ValueError( + f"No packaged Lebedev rule has precision >= {required_precision}" + ) + else: + lebedev_precision = int(lebedev_precision) + lebedev_npoints = LEBEDEV_PRECISION_TO_NPOINTS.get(lebedev_precision) + if lebedev_npoints is None: + raise ValueError( + f"Lebedev rule with precision {lebedev_precision} is not packaged" + ) + + # A quadratic product followed by analysis can contain gamma frequencies up + # to ``3*kmax``. A uniform grid with more samples than that frequency + # resolves the integer Fourier modes exactly. + n_gamma = 1 if kmax_i == 0 else 3 * kmax_i + 1 + return int(lebedev_precision), int(lebedev_npoints), int(n_gamma) + + +def _normalize_s2_grid_resolution( + lmax: int, + mmax: int, + grid_resolution_list: list[int] | None, + *, + method: str, +) -> list[int]: + """Resolve default grids or validate already-resolved low-level grids.""" + method = str(method).lower() + if grid_resolution_list is None: + return resolve_s2_grid_resolution(lmax, mmax, method=method) + if method == "lebedev": + if len(grid_resolution_list) != 2: + raise ValueError( + "Lebedev `grid_resolution_list` must be [precision, n_points]" + ) + precision = int(grid_resolution_list[0]) + n_points = int(grid_resolution_list[1]) + expected_n_points = LEBEDEV_PRECISION_TO_NPOINTS.get(precision) + if expected_n_points != n_points: + raise ValueError( + "Lebedev `grid_resolution_list` must match a packaged " + f"[precision, n_points] pair; got [{precision}, {n_points}]" + ) + return [precision, n_points] + + if len(grid_resolution_list) != 2: + raise ValueError("`grid_resolution_list` must contain two integers") + resolution = [int(grid_resolution_list[0]), int(grid_resolution_list[1])] + if resolution[0] < 1 or resolution[1] < 1: + raise ValueError("grid resolutions must be positive") + return resolution + + +def _build_so3_frame_set(kmax: int) -> list[int]: + """Build the symmetric frame-index set with zero first.""" + kmax_i = int(kmax) + if kmax_i < 0: + raise ValueError("`kmax` must be non-negative") + return [0, *[frame for kk in range(1, kmax_i + 1) for frame in (-kk, kk)]] diff --git a/deepmd/pt/model/descriptor/sezm_nn/so2.py b/deepmd/pt/model/descriptor/sezm_nn/so2.py index ab0bd84f05..efa175dcef 100644 --- a/deepmd/pt/model/descriptor/sezm_nn/so2.py +++ b/deepmd/pt/model/descriptor/sezm_nn/so2.py @@ -38,8 +38,6 @@ from .activation import ( GatedActivation, - SwiGLUS2Activation, - resolve_s2_grid_resolution, ) from .attention import ( segment_envelope_gated_softmax, @@ -47,6 +45,10 @@ from .attn_res import ( DepthAttnRes, ) +from .grid_net import ( + S2GridNet, + SO3GridNet, +) from .indexing import ( build_m_major_index, build_m_major_l_index, @@ -60,6 +62,9 @@ ReducedEquivariantRMSNorm, ScalarRMSNorm, ) +from .projection import ( + resolve_s2_grid_resolution, +) from .so3 import ( ChannelLinear, FocusLinear, @@ -710,6 +715,8 @@ class SO2Convolution(nn.Module): Maximum degree. mmax Maximum SO(2) order (|m|). If None, defaults to lmax. + kmax + Maximum Wigner-D frame order (|k|) used by SO(3) grid branches. channels Number of channels per (l, m) coefficient. n_focus @@ -754,6 +761,31 @@ class SO2Convolution(nn.Module): If True, replace each intermediate reduced-layout gate with S2-grid SwiGLU. Intermediate ``SO2Linear`` layers then output ``2 * focus_dim`` channels before the activation folds them back to ``focus_dim``. + node_wise_grid_mlp + If True, select the polynomial grid MLP operation for the node-wise + source-destination grid product. + node_wise_grid_branch + Number of scalar-routed polynomial product branches for the node-wise + grid product. ``0`` disables branch mixing; positive values take + precedence over ``node_wise_grid_mlp``. + message_node_grid_mlp + If True, select the polynomial grid MLP operation for the message-node + grid product. + message_node_grid_branch + Number of scalar-routed polynomial product branches for the + message-node grid product. ``0`` disables branch mixing; positive + values take precedence over ``message_node_grid_mlp``. + node_wise_s2 + If True, add an edge-local S2 product branch between radial-fused source + features and destination features in the same edge frame. + node_wise_so3 + If True, use the corresponding edge-local SO(3) Wigner-D grid branch. + message_node_s2 + If True, add a packed-layout S2 product branch between the aggregated + hidden message and the destination node features before ``post_focus_mix``. + message_node_so3 + If True, use the corresponding post-aggregation SO(3) Wigner-D grid + branch. lebedev_quadrature If True, use Lebedev quadrature for the S2 projector. activation_function @@ -788,6 +820,7 @@ def __init__( *, lmax: int, mmax: int | None = None, + kmax: int = 1, channels: int, n_focus: int = 1, focus_dim: int = 0, @@ -801,6 +834,14 @@ def __init__( atten_v_proj: bool = False, atten_o_proj: bool = False, s2_activation: bool = False, + node_wise_grid_mlp: bool = False, + node_wise_grid_branch: int = 0, + message_node_grid_mlp: bool = False, + message_node_grid_branch: int = 0, + node_wise_s2: bool = False, + node_wise_so3: bool = False, + message_node_s2: bool = False, + message_node_so3: bool = False, lebedev_quadrature: bool = False, activation_function: str = "silu", mlp_bias: bool = False, @@ -819,6 +860,9 @@ def __init__( raise ValueError("`mmax` must be non-negative") if self.mmax > self.lmax: raise ValueError("`mmax` must be <= `lmax`") + self.kmax = int(kmax) + if self.kmax < 0: + raise ValueError("`kmax` must be non-negative") self.channels = int(channels) self.n_focus = int(n_focus) if self.n_focus < 1: @@ -848,6 +892,16 @@ def __init__( self.use_atten_v_proj = bool(atten_v_proj) self.use_atten_o_proj = bool(atten_o_proj) self.s2_activation = bool(s2_activation) + self.node_wise_grid_mlp = bool(node_wise_grid_mlp) + self.node_wise_grid_branch = int(node_wise_grid_branch) + self.message_node_grid_mlp = bool(message_node_grid_mlp) + self.message_node_grid_branch = int(message_node_grid_branch) + if min(self.node_wise_grid_branch, self.message_node_grid_branch) < 0: + raise ValueError("grid branch counts must be non-negative") + self.node_wise_s2 = bool(node_wise_s2) + self.node_wise_so3 = bool(node_wise_so3) + self.message_node_s2 = bool(message_node_s2) + self.message_node_so3 = bool(message_node_so3) self.lebedev_quadrature = bool(lebedev_quadrature) self.s2_grid_method = "lebedev" if self.lebedev_quadrature else "e3nn" self.s2_grid_resolution = resolve_s2_grid_resolution( @@ -855,6 +909,16 @@ def __init__( self.mmax, method=self.s2_grid_method, ) + base_full_grid_resolution = resolve_s2_grid_resolution( + self.lmax, + self.lmax, + method=self.s2_grid_method, + ) + self.s2_full_grid_resolution = ( + [max(base_full_grid_resolution), max(base_full_grid_resolution)] + if self.s2_grid_method == "e3nn" + else base_full_grid_resolution + ) self.activation_function = str(activation_function) if self.n_atten_head < 0: raise ValueError("`n_atten_head` must be non-negative") @@ -928,6 +992,8 @@ def __init__( seed_depth_attn = child_seed(seed, 5) seed_radial_hidden = child_seed(seed, 6) seed_radial_degree = child_seed(seed, 7) + seed_node_wise_s2 = child_seed(seed, 8) + seed_message_node_s2 = child_seed(seed, 9) # === Step 3. Multiple SO2Linear layers === self.so2_linears = nn.ModuleList( @@ -977,12 +1043,14 @@ def __init__( for i in range(max(0, self.so2_layers - 1)): if self.s2_activation: non_linearities.append( - SwiGLUS2Activation( + S2GridNet( lmax=self.lmax, mmax=self.mmax, channels=self.so2_focus_dim, - dtype=self.compute_dtype, n_focus=self.n_focus, + mode="self", + op_type="glu", + dtype=self.compute_dtype, layout="nfdc", grid_resolution_list=self.s2_grid_resolution, coefficient_layout="m_major", @@ -1229,6 +1297,94 @@ def __init__( seed=seed_radial_degree, trainable=trainable, ) + node_wise_op = ( + "branch" + if self.node_wise_grid_branch > 0 + else ("mlp" if self.node_wise_grid_mlp else "glu") + ) + node_wise_branches = max(1, self.node_wise_grid_branch) + message_node_op = ( + "branch" + if self.message_node_grid_branch > 0 + else ("mlp" if self.message_node_grid_mlp else "glu") + ) + message_node_branches = max(1, self.message_node_grid_branch) + self.node_wise_grid_product: S2GridNet | SO3GridNet | None = None + if self.node_wise_s2 or self.node_wise_so3: + if self.node_wise_so3: + self.node_wise_grid_product = SO3GridNet( + lmax=self.lmax, + mmax=self.mmax, + kmax=self.kmax, + channels=self.so2_focus_dim, + n_focus=self.n_focus, + mode="cross", + op_type=node_wise_op, + dtype=self.compute_dtype, + layout="flat", + coefficient_layout="m_major", + grid_branches=node_wise_branches, + mlp_bias=self.mlp_bias, + residual_scale_init=1e-3, + trainable=trainable, + seed=seed_node_wise_s2, + ) + else: + self.node_wise_grid_product = S2GridNet( + lmax=self.lmax, + mmax=self.mmax, + channels=self.so2_focus_dim, + n_focus=self.n_focus, + mode="cross", + op_type=node_wise_op, + dtype=self.compute_dtype, + layout="flat", + grid_resolution_list=self.s2_grid_resolution, + grid_method=self.s2_grid_method, + grid_branches=node_wise_branches, + mlp_bias=self.mlp_bias, + residual_scale_init=1e-3, + trainable=trainable, + seed=seed_node_wise_s2, + ) + self.message_node_grid_product: S2GridNet | SO3GridNet | None = None + if self.message_node_s2 or self.message_node_so3: + if self.message_node_so3: + self.message_node_grid_product = SO3GridNet( + lmax=self.lmax, + kmax=self.kmax, + channels=self.so2_focus_dim, + n_focus=self.n_focus, + mode="cross", + op_type=message_node_op, + dtype=self.compute_dtype, + layout="flat", + coefficient_layout="packed", + grid_branches=message_node_branches, + mlp_bias=self.mlp_bias, + residual_scale_init=1e-3, + trainable=trainable, + seed=seed_message_node_s2, + ) + else: + self.message_node_grid_product = S2GridNet( + lmax=self.lmax, + mmax=self.lmax, + channels=self.so2_focus_dim, + n_focus=self.n_focus, + mode="cross", + op_type=message_node_op, + dtype=self.compute_dtype, + layout="flat", + grid_resolution_list=self.s2_full_grid_resolution, + grid_method=self.s2_grid_method, + grid_branches=message_node_branches, + mlp_bias=self.mlp_bias, + residual_scale_init=1e-3, + trainable=trainable, + coefficient_layout="packed", + seed=seed_message_node_s2, + ) # === Step 9. Pre-focus channel mixing === # This projects the full channel width before the SO(2) focus split. @@ -1291,6 +1447,7 @@ def forward( # === Step 2. Rotate to edge-aligned local frame === with nvtx_range("SO2Conv/rotate_to_local"): D_full = edge_cache.D_full + x_dst_local: torch.Tensor | None = None if self.use_triton_rotations and not self.training: x_local = rotate_to_local_triton( x=x, @@ -1300,6 +1457,15 @@ def forward( dim_full=self.ebed_dim_full, rotation_mode=self.triton_rotation_mode, ) # (E, D_m, C_wide) + if self.node_wise_grid_product is not None: + x_dst_local = rotate_to_local_triton( + x=x, + src=dst, + wigner=D_full, + coeff_index=self.coeff_index_m, + dim_full=self.ebed_dim_full, + rotation_mode=self.triton_rotation_mode, + ) # (E, D_m, C_wide) else: D_m_prime = project_D_to_m( D_full=D_full, @@ -1311,6 +1477,9 @@ def forward( ) x_src = x.index_select(0, src) # (E, D, C_wide) x_local = torch.bmm(D_m_prime, x_src) # (E, D_m, C_wide) + if self.node_wise_grid_product is not None: + x_dst = x.index_select(0, dst) # (E, D, C_wide) + x_dst_local = torch.bmm(D_m_prime, x_dst) # (E, D_m, C_wide) # === Step 3. Select radial/type features for reduced layout === with nvtx_range("SO2Conv/radial_fuse"): @@ -1321,6 +1490,11 @@ def forward( x_local.mul_(rad_feat) else: x_local = self.radial_degree_mixer(x_local, rad_feat) + if self.node_wise_grid_product is not None: + x_local = x_local + self.node_wise_grid_product( + x_local, + x_dst_local, + ) rad_feat_l0_focus = rad_feat[:, 0, :].reshape( n_edge, self.n_focus, self.so2_focus_dim ) # (E, F, Cf) @@ -1600,7 +1774,12 @@ def apply_bias_correction( n_node, self.ebed_dim_full, self.hidden_channels ).to(dtype=self.dtype) # (N, D, C_wide) - # === Step 9. Final channel mixing === + # === Step 9. Optional message-node grid product === + if self.message_node_grid_product is not None: + with nvtx_range("SO2Conv/message_node_grid"): + out = out + self.message_node_grid_product(out, x) + + # === Step 10. Final channel mixing === with nvtx_range("SO2Conv/post_focus_mix"): out = self.post_focus_mix(out.unsqueeze(2)).squeeze(2) return out # (N, D, C) @@ -1614,6 +1793,7 @@ def serialize(self) -> dict[str, Any]: "config": { "lmax": self.lmax, "mmax": self.mmax, + "kmax": self.kmax, "channels": self.channels, "n_focus": self.n_focus, "focus_dim": self.focus_dim, @@ -1627,6 +1807,14 @@ def serialize(self) -> dict[str, Any]: "atten_v_proj": self.use_atten_v_proj, "atten_o_proj": self.use_atten_o_proj, "s2_activation": self.s2_activation, + "node_wise_grid_mlp": self.node_wise_grid_mlp, + "node_wise_grid_branch": self.node_wise_grid_branch, + "message_node_grid_mlp": self.message_node_grid_mlp, + "message_node_grid_branch": self.message_node_grid_branch, + "node_wise_s2": self.node_wise_s2, + "node_wise_so3": self.node_wise_so3, + "message_node_s2": self.message_node_s2, + "message_node_so3": self.message_node_so3, "lebedev_quadrature": self.lebedev_quadrature, "activation_function": self.activation_function, "mlp_bias": self.mlp_bias, diff --git a/deepmd/pt/model/descriptor/sezm_nn/wignerd.py b/deepmd/pt/model/descriptor/sezm_nn/wignerd.py index ca90d6978c..f668e6b657 100644 --- a/deepmd/pt/model/descriptor/sezm_nn/wignerd.py +++ b/deepmd/pt/model/descriptor/sezm_nn/wignerd.py @@ -134,33 +134,48 @@ class WignerSmallOrderCoefficients(nn.Module): """ Precomputed low-order quaternion polynomial kernels in the SeZM packed basis. - The tensors in this container provide the specialized ``l=2`` and ``l=3,4`` - kernels used by the hybrid Wigner runtime: - - ``C_l2`` stores the degree-4 tensor-contraction coefficients; - - ``C_l3`` / ``C_l4`` store flattened monomial coefficient matrices; + Only kernels required by the owning ``WignerDCalculator`` are registered: + + - ``C_l2`` stores the degree-4 tensor-contraction coefficients. + - ``C_l3`` .. ``C_l10`` store flattened monomial coefficient matrices. - ``C_combined_l3l4`` lifts the ``l=3`` basis to degree 8 and stacks it with - ``l=4`` so both blocks can be produced by one matrix multiply; - - ``exp_l3`` / ``exp_l4`` store the monomial exponent tables used by the runtime - gather/prod path. + ``l=4`` so both blocks can be produced by one matrix multiply. + - ``C_combined_l5l6`` applies the same degree-12 stacking for ``l=5,6``. + - ``C_combined_l7l8`` applies the same degree-16 stacking for ``l=7,8``. + - ``C_combined_l9l10`` applies the same degree-20 stacking for ``l=9,10``. + - ``exp_l3`` .. ``exp_l10`` store the monomial exponent tables used by the + runtime gather/prod path. """ + _EXTRA_KERNELS_BY_LMAX: ClassVar[tuple[tuple[int, tuple[str, ...]], ...]] = ( + (3, ("C_l3", "exp_l3")), + (4, ("C_l4", "C_combined_l3l4", "exp_l4")), + (5, ("C_l5", "exp_l5")), + (6, ("C_l6", "C_combined_l5l6", "exp_l6")), + (7, ("C_l7", "exp_l7")), + (8, ("C_l8", "C_combined_l7l8", "exp_l8")), + (9, ("C_l9", "exp_l9")), + (10, ("C_l10", "C_combined_l9l10", "exp_l10")), + ) + def __init__( self, *, - C_l2: torch.Tensor, - C_l3: torch.Tensor, - C_l4: torch.Tensor, - C_combined_l3l4: torch.Tensor, - exp_l3: torch.Tensor, - exp_l4: torch.Tensor, + lmax: int, + kernels: dict[str, torch.Tensor], ) -> None: super().__init__() - self.register_buffer("C_l2", C_l2, persistent=True) - self.register_buffer("C_l3", C_l3, persistent=True) - self.register_buffer("C_l4", C_l4, persistent=True) - self.register_buffer("C_combined_l3l4", C_combined_l3l4, persistent=True) - self.register_buffer("exp_l3", exp_l3, persistent=True) - self.register_buffer("exp_l4", exp_l4, persistent=True) + for name in self.required_kernel_names(lmax): + self.register_buffer(name, kernels[name], persistent=False) + + @classmethod + def required_kernel_names(cls, lmax: int) -> tuple[str, ...]: + """Return low-order kernel names required for ``lmax``.""" + names = ["C_l2"] + for threshold, extra_names in cls._EXTRA_KERNELS_BY_LMAX: + if lmax >= threshold: + names.extend(extra_names) + return tuple(names) def _safe_norm_nd(x: torch.Tensor, eps: float = 1e-7) -> torch.Tensor: @@ -375,7 +390,10 @@ class WignerDCalculator(nn.Module): - ``l=1``: direct quaternion -> Cartesian rotation -> real l=1 block; - ``l=2``: dedicated degree-4 quaternion tensor contraction; - ``l=3,4``: dedicated quaternion monomial kernels; - - ``l>=5``: generic quaternion polynomial path with precomputed coefficient tables. + - ``l=5,6``: dedicated quaternion monomial kernels; + - ``l=7,8``: dedicated quaternion monomial kernels; + - ``l=9,10``: dedicated quaternion monomial kernels; + - ``l>=11``: generic quaternion polynomial path with precomputed coefficient tables. """ _SMALL_ORDER_CACHE_CPU_FP64: ClassVar[dict[str, torch.Tensor] | None] = None @@ -395,7 +413,7 @@ def __init__( self.device = env.DEVICE self.eps = float(eps) self.dim_full = (self.lmax + 1) ** 2 - self.poly_lmin = 5 + self.poly_lmin = 11 self.poly_offset = self.poly_lmin * self.poly_lmin self.register_buffer( @@ -412,6 +430,7 @@ def __init__( if self.lmax >= 2: self.small_order_kernels = self._build_small_order_kernels( + lmax=self.lmax, dtype=self.dtype, device=self.device, ) @@ -503,6 +522,36 @@ def forward( with nvtx_range("WignerD/l3"): D_full[:, 9:16, 9:16] = self._compute_l3_block(edge_quaternion) + if self.lmax >= 5: + if self.lmax >= 6: + with nvtx_range("WignerD/l5l6"): + D_l5, D_l6 = self._compute_l5l6_blocks(edge_quaternion) + D_full[:, 25:36, 25:36] = D_l5 + D_full[:, 36:49, 36:49] = D_l6 + else: + with nvtx_range("WignerD/l5"): + D_full[:, 25:36, 25:36] = self._compute_l5_block(edge_quaternion) + + if self.lmax >= 7: + if self.lmax >= 8: + with nvtx_range("WignerD/l7l8"): + D_l7, D_l8 = self._compute_l7l8_blocks(edge_quaternion) + D_full[:, 49:64, 49:64] = D_l7 + D_full[:, 64:81, 64:81] = D_l8 + else: + with nvtx_range("WignerD/l7"): + D_full[:, 49:64, 49:64] = self._compute_l7_block(edge_quaternion) + + if self.lmax >= 9: + if self.lmax >= 10: + with nvtx_range("WignerD/l9l10"): + D_l9, D_l10 = self._compute_l9l10_blocks(edge_quaternion) + D_full[:, 81:100, 81:100] = D_l9 + D_full[:, 100:121, 100:121] = D_l10 + else: + with nvtx_range("WignerD/l9"): + D_full[:, 81:100, 81:100] = self._compute_l9_block(edge_quaternion) + if self.lmax >= self.poly_lmin: with nvtx_range("WignerD/polynomial"): ra_re, ra_im, rb_re, rb_im = self._quaternion_to_ra_rb_real( @@ -533,33 +582,174 @@ def forward( Dt_full = D_full.transpose(-1, -2).contiguous() return D_full, Dt_full + def forward_zonal( + self, + edge_quaternion: torch.Tensor, + lmin: int = 1, + ) -> torch.Tensor: + """ + Build local ``m=0`` to global coupling for GIE. + + The returned layout matches the packed node rows for degrees + ``lmin..lmax``: each degree contributes ``2l+1`` values in packed + ``m=-l..l`` order. These values are equivalent to gathering + ``Dt_full[:, row(l, m), col(l, 0)]`` from :meth:`forward` over the + same degree range. + + Parameters + ---------- + edge_quaternion + Unit quaternions with shape ``(E, 4)`` representing the global->local + edge rotation. + lmin + First degree to return. + + Returns + ------- + torch.Tensor + Zonal coupling with shape + ``(E, (lmax + 1) ** 2 - lmin ** 2)``. + """ + lmin = int(lmin) + if lmin < 1: + raise ValueError("`lmin` must be >= 1") + n_edge = edge_quaternion.shape[0] + if self.lmax < lmin: + return torch.empty( + n_edge, + 0, + dtype=self.dtype, + device=edge_quaternion.device, + ) + edge_quaternion = quaternion_normalize( + edge_quaternion.to(dtype=self.dtype), + eps=self.eps, + ) + + with nvtx_range("WignerD/zonal"): + zonal_blocks: list[torch.Tensor] = [] + if lmin <= 1 <= self.lmax: + zonal_blocks.append(self._compute_l1_block(edge_quaternion)[:, 1, :]) + + if lmin <= 2 <= self.lmax: + zonal_blocks.append(self._compute_l2_block(edge_quaternion)[:, 2, :]) + + if self.lmax >= 3 and lmin <= 4: + if self.lmax >= 4: + D_l3, D_l4 = self._compute_l3l4_blocks(edge_quaternion) + if lmin <= 3: + zonal_blocks.append(D_l3[:, 3, :]) + zonal_blocks.append(D_l4[:, 4, :]) + else: + zonal_blocks.append( + self._compute_l3_block(edge_quaternion)[:, 3, :] + ) + + if self.lmax >= 5 and lmin <= 6: + if self.lmax >= 6: + D_l5, D_l6 = self._compute_l5l6_blocks(edge_quaternion) + if lmin <= 5: + zonal_blocks.append(D_l5[:, 5, :]) + zonal_blocks.append(D_l6[:, 6, :]) + else: + zonal_blocks.append( + self._compute_l5_block(edge_quaternion)[:, 5, :] + ) + + if self.lmax >= 7 and lmin <= 8: + if self.lmax >= 8: + D_l7, D_l8 = self._compute_l7l8_blocks(edge_quaternion) + if lmin <= 7: + zonal_blocks.append(D_l7[:, 7, :]) + zonal_blocks.append(D_l8[:, 8, :]) + else: + zonal_blocks.append( + self._compute_l7_block(edge_quaternion)[:, 7, :] + ) + + if self.lmax >= 9 and lmin <= 10: + if self.lmax >= 10: + D_l9, D_l10 = self._compute_l9l10_blocks(edge_quaternion) + if lmin <= 9: + zonal_blocks.append(D_l9[:, 9, :]) + zonal_blocks.append(D_l10[:, 10, :]) + else: + zonal_blocks.append( + self._compute_l9_block(edge_quaternion)[:, 9, :] + ) + + if self.lmax >= self.poly_lmin and lmin <= self.lmax: + ra_re, ra_im, rb_re, rb_im = self._quaternion_to_ra_rb_real( + edge_quaternion + ) + D_re, D_im = self._wigner_d_matrix_realpair( + ra_re, + ra_im, + rb_re, + rb_im, + self.poly_coeffs, + dtype=self.dtype, + ) + D_poly = self._wigner_d_pair_to_real( + D_re, + D_im, + ( + self.poly_u_re, + self.poly_u_im, + self.poly_u_re_t, + self.poly_u_im_t, + ), + lmax=self.lmax, + lmin=self.poly_lmin, + ) + poly_lmin = max(lmin, self.poly_lmin) + offset = 0 + for degree in range(self.poly_lmin, self.lmax + 1): + block_size = 2 * degree + 1 + block_end = offset + block_size + if degree >= poly_lmin: + zonal_blocks.append( + D_poly[:, offset + degree, offset:block_end] + ) + offset = block_end + + return torch.cat(zonal_blocks, dim=1) + @classmethod - def _get_small_order_cache_cpu_fp64(cls) -> dict[str, torch.Tensor]: - """Generate the low-order kernel coefficients once per process on CPU fp64.""" + def _get_small_order_cache_cpu_fp64(cls, lmax: int) -> dict[str, torch.Tensor]: + """Generate the required low-order kernel coefficients on CPU fp64.""" + target_lmax = min(max(int(lmax), 2), 10) if cls._SMALL_ORDER_CACHE_CPU_FP64 is None: - cls._SMALL_ORDER_CACHE_CPU_FP64 = cls._generate_small_order_cache_cpu_fp64() - return cls._SMALL_ORDER_CACHE_CPU_FP64 + cls._SMALL_ORDER_CACHE_CPU_FP64 = {} + cache = cls._SMALL_ORDER_CACHE_CPU_FP64 + required_names = WignerSmallOrderCoefficients.required_kernel_names(target_lmax) + if any(name not in cache for name in required_names): + cache.update(cls._generate_small_order_cache_cpu_fp64(target_lmax)) + return cache @classmethod def _build_small_order_kernels( cls, *, + lmax: int, dtype: torch.dtype, device: torch.device, ) -> WignerSmallOrderCoefficients: - """Instantiate the specialized ``l=2,3,4`` kernels on the requested device/dtype.""" - cache = cls._get_small_order_cache_cpu_fp64() + """Instantiate the specialized ``l=2..10`` kernels on the requested device/dtype.""" + cache = cls._get_small_order_cache_cpu_fp64(lmax) + kernels = {} + for name in WignerSmallOrderCoefficients.required_kernel_names(lmax): + if name.startswith("exp_"): + kernels[name] = cache[name].to(device=device) + else: + kernels[name] = cache[name].to(device=device, dtype=dtype) return WignerSmallOrderCoefficients( - C_l2=cache["C_l2"].to(device=device, dtype=dtype), - C_l3=cache["C_l3"].to(device=device, dtype=dtype), - C_l4=cache["C_l4"].to(device=device, dtype=dtype), - C_combined_l3l4=cache["C_combined_l3l4"].to(device=device, dtype=dtype), - exp_l3=cache["exp_l3"].to(device=device), - exp_l4=cache["exp_l4"].to(device=device), + lmax=lmax, + kernels=kernels, ) @classmethod - def _generate_small_order_cache_cpu_fp64(cls) -> dict[str, torch.Tensor]: + def _generate_small_order_cache_cpu_fp64(cls, lmax: int) -> dict[str, torch.Tensor]: """ Generate the low-order kernel coefficients from the generic SeZM reference path. @@ -567,65 +757,59 @@ def _generate_small_order_cache_cpu_fp64(cls) -> dict[str, torch.Tensor]: validated against the generic quaternion polynomial evaluator, and then reused by every `WignerDCalculator` instance. """ + target_lmax = min(max(int(lmax), 2), 10) dtype = torch.float64 device = torch.device("cpu") generator = torch.Generator() generator.manual_seed(20260404) - q_fit = torch.randn(2048, 4, dtype=dtype, device=device, generator=generator) + max_monomials = math.comb(2 * target_lmax + 3, 3) + n_fit = min(2048, max(128, 2 * max_monomials)) + q_fit = torch.randn(n_fit, 4, dtype=dtype, device=device, generator=generator) q_fit = quaternion_normalize(q_fit, eps=torch.finfo(dtype).eps) ref_blocks = cls._compute_generic_reference_blocks( - q_fit, lmax=4, dtype=dtype, device=device + q_fit, lmax=target_lmax, dtype=dtype, device=device ) - monomials_l2 = cls._generate_monomials(4, 4) - monomials_l3 = cls._generate_monomials(4, 6) - monomials_l4 = cls._generate_monomials(4, 8) - exp_l2 = cls._monomials_to_exponent_tensor(monomials_l2, device=device) - exp_l3 = cls._monomials_to_exponent_tensor(monomials_l3, device=device) - exp_l4 = cls._monomials_to_exponent_tensor(monomials_l4, device=device) - - C_l2_flat = cls._solve_monomial_coefficients( - q_fit, - ref_blocks[2], - exp_l2, - ) - C_l3 = cls._solve_monomial_coefficients(q_fit, ref_blocks[3], exp_l3) - C_l4 = cls._solve_monomial_coefficients(q_fit, ref_blocks[4], exp_l4) - C_l2 = cls._build_l2_contraction_tensor(C_l2_flat, monomials_l2) - C_combined_l3l4 = cls._build_combined_l3l4( - C_l3, C_l4, monomials_l3, monomials_l4 - ) + monomials: dict[int, list[tuple[int, int, int, int]]] = {} + exponents: dict[int, torch.Tensor] = {} + coefficients: dict[int, torch.Tensor] = {} + cache: dict[str, torch.Tensor] = {} - q_val = torch.randn(256, 4, dtype=dtype, device=device, generator=generator) - q_val = quaternion_normalize(q_val, eps=torch.finfo(dtype).eps) - ref_val = cls._compute_generic_reference_blocks( - q_val, lmax=4, dtype=dtype, device=device - ) - test_val = cls._evaluate_small_order_blocks( - q_val, - C_l2=C_l2, - C_l3=C_l3, - C_l4=C_l4, - exp_l3=exp_l3, - exp_l4=exp_l4, - ) - thresholds = {2: 1e-10, 3: 1e-10, 4: 1e-10} - for ell in (2, 3, 4): - err = (test_val[ell] - ref_val[ell]).abs().max().item() - if err > thresholds[ell]: - raise RuntimeError( - f"Failed to generate stable SeZM Wigner coefficients for l={ell}: max_err={err}" + for ell in range(2, target_lmax + 1): + monomials[ell] = cls._generate_monomials(4, 2 * ell) + exponents[ell] = cls._monomials_to_exponent_tensor( + monomials[ell], device=device + ) + coeff = cls._solve_monomial_coefficients( + q_fit, + ref_blocks[ell], + exponents[ell], + ) + if ell == 2: + cache["C_l2"] = cls._build_l2_contraction_tensor(coeff, monomials[2]) + else: + coefficients[ell] = coeff + cache[f"C_l{ell}"] = coeff + cache[f"exp_l{ell}"] = exponents[ell] + + combined_builders = { + 4: ("C_combined_l3l4", cls._build_combined_l3l4), + 6: ("C_combined_l5l6", cls._build_combined_l5l6), + 8: ("C_combined_l7l8", cls._build_combined_l7l8), + 10: ("C_combined_l9l10", cls._build_combined_l9l10), + } + for even_ell, (name, builder) in combined_builders.items(): + if target_lmax >= even_ell: + odd_ell = even_ell - 1 + cache[name] = builder( + coefficients[odd_ell], + coefficients[even_ell], + monomials[odd_ell], + monomials[even_ell], ) - return { - "C_l2": C_l2, - "C_l3": C_l3, - "C_l4": C_l4, - "C_combined_l3l4": C_combined_l3l4, - "exp_l3": exp_l3, - "exp_l4": exp_l4, - } + return cache @classmethod def _compute_generic_reference_blocks( @@ -636,7 +820,7 @@ def _compute_generic_reference_blocks( dtype: torch.dtype, device: torch.device, ) -> dict[int, torch.Tensor]: - """Evaluate the generic SeZM polynomial path and extract the ``l=2,3,4`` blocks.""" + """Evaluate the generic SeZM polynomial path and extract per-degree blocks.""" coeffs = cls._precompute_wigner_coefficients( lmax, dtype=dtype, @@ -665,11 +849,14 @@ def _compute_generic_reference_blocks( lmax=lmax, lmin=2, ) - return { - 2: D_ref[:, 0:5, 0:5], - 3: D_ref[:, 5:12, 5:12], - 4: D_ref[:, 12:21, 12:21], - } + ref_blocks: dict[int, torch.Tensor] = {} + offset = 0 + for ell in range(2, lmax + 1): + block_size = 2 * ell + 1 + block_end = offset + block_size + ref_blocks[ell] = D_ref[:, offset:block_end, offset:block_end] + offset = block_end + return ref_blocks @classmethod def _solve_monomial_coefficients( @@ -707,39 +894,6 @@ def _build_l2_contraction_tensor( C_l2[i, j, p0, p1, p2, p3] = share return C_l2 - @classmethod - def _evaluate_small_order_blocks( - cls, - edge_quaternion: torch.Tensor, - *, - C_l2: torch.Tensor, - C_l3: torch.Tensor, - C_l4: torch.Tensor, - exp_l3: torch.Tensor, - exp_l4: torch.Tensor, - ) -> dict[int, torch.Tensor]: - """Evaluate the specialized ``l=2,3,4`` kernels for validation and caching.""" - q2 = edge_quaternion.unsqueeze(-1) * edge_quaternion.unsqueeze(-2) - q4 = q2.unsqueeze(-1).unsqueeze(-1) * q2.unsqueeze(-3).unsqueeze(-3) - D_l2 = torch.einsum("nabcd,ijabcd->nij", q4, C_l2) - - powers6 = cls._precompute_powers(edge_quaternion, 6) - M3 = cls._build_monomial_matrix(powers6, exp_l3) - D_l3 = torch.matmul(M3, C_l3.transpose(0, 1)).view( - edge_quaternion.shape[0], 7, 7 - ) - - powers8 = cls._precompute_powers(edge_quaternion, 8) - M4 = cls._build_monomial_matrix(powers8, exp_l4) - D_l4 = torch.matmul(M4, C_l4.transpose(0, 1)).view( - edge_quaternion.shape[0], 9, 9 - ) - return { - 2: D_l2, - 3: D_l3, - 4: D_l4, - } - @staticmethod def _generate_monomials( n_vars: int, @@ -796,6 +950,81 @@ def _build_combined_l3l4( C_l3_lifted[:, mono8_to_idx[mono8]] += C_l3[:, j] return torch.cat([C_l3_lifted, C_l4], dim=0) + @staticmethod + def _build_combined_l5l6( + C_l5: torch.Tensor, + C_l6: torch.Tensor, + monomials_l5: list[tuple[int, int, int, int]], + monomials_l6: list[tuple[int, int, int, int]], + ) -> torch.Tensor: + """Lift the ``l=5`` basis to degree 12 and stack it with the ``l=6`` basis.""" + mono12_to_idx = {mono: idx for idx, mono in enumerate(monomials_l6)} + C_l5_lifted = torch.zeros( + C_l5.shape[0], + len(monomials_l6), + dtype=C_l5.dtype, + device=C_l5.device, + ) + for j, (a, b, c, d) in enumerate(monomials_l5): + for mono12 in ( + (a + 2, b, c, d), + (a, b + 2, c, d), + (a, b, c + 2, d), + (a, b, c, d + 2), + ): + C_l5_lifted[:, mono12_to_idx[mono12]] += C_l5[:, j] + return torch.cat([C_l5_lifted, C_l6], dim=0) + + @staticmethod + def _build_combined_l7l8( + C_l7: torch.Tensor, + C_l8: torch.Tensor, + monomials_l7: list[tuple[int, int, int, int]], + monomials_l8: list[tuple[int, int, int, int]], + ) -> torch.Tensor: + """Lift the ``l=7`` basis to degree 16 and stack it with the ``l=8`` basis.""" + mono16_to_idx = {mono: idx for idx, mono in enumerate(monomials_l8)} + C_l7_lifted = torch.zeros( + C_l7.shape[0], + len(monomials_l8), + dtype=C_l7.dtype, + device=C_l7.device, + ) + for j, (a, b, c, d) in enumerate(monomials_l7): + for mono16 in ( + (a + 2, b, c, d), + (a, b + 2, c, d), + (a, b, c + 2, d), + (a, b, c, d + 2), + ): + C_l7_lifted[:, mono16_to_idx[mono16]] += C_l7[:, j] + return torch.cat([C_l7_lifted, C_l8], dim=0) + + @staticmethod + def _build_combined_l9l10( + C_l9: torch.Tensor, + C_l10: torch.Tensor, + monomials_l9: list[tuple[int, int, int, int]], + monomials_l10: list[tuple[int, int, int, int]], + ) -> torch.Tensor: + """Lift the ``l=9`` basis to degree 20 and stack it with the ``l=10`` basis.""" + mono20_to_idx = {mono: idx for idx, mono in enumerate(monomials_l10)} + C_l9_lifted = torch.zeros( + C_l9.shape[0], + len(monomials_l10), + dtype=C_l9.dtype, + device=C_l9.device, + ) + for j, (a, b, c, d) in enumerate(monomials_l9): + for mono20 in ( + (a + 2, b, c, d), + (a, b + 2, c, d), + (a, b, c + 2, d), + (a, b, c, d + 2), + ): + C_l9_lifted[:, mono20_to_idx[mono20]] += C_l9[:, j] + return torch.cat([C_l9_lifted, C_l10], dim=0) + @staticmethod def _precompute_powers( q: torch.Tensor, @@ -880,6 +1109,99 @@ def _compute_l3l4_blocks( D_l4 = D_flat[:, 49:].view(edge_quaternion.shape[0], 9, 9) return D_l3, D_l4 + def _compute_l5_block(self, edge_quaternion: torch.Tensor) -> torch.Tensor: + """Compute the ``l=5`` block from the dedicated degree-10 monomial kernel.""" + powers = self._precompute_powers(edge_quaternion, 10) + monomials = self._build_monomial_matrix( + powers, + self.small_order_kernels.exp_l5, + ) + D_flat = torch.matmul( + monomials, + self.small_order_kernels.C_l5.transpose(0, 1), + ) + return D_flat.view(edge_quaternion.shape[0], 11, 11) + + def _compute_l5l6_blocks( + self, + edge_quaternion: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + """Compute the ``l=5`` and ``l=6`` blocks from one shared degree-12 kernel.""" + powers = self._precompute_powers(edge_quaternion, 12) + monomials = self._build_monomial_matrix( + powers, + self.small_order_kernels.exp_l6, + ) + D_flat = torch.matmul( + monomials, + self.small_order_kernels.C_combined_l5l6.transpose(0, 1), + ) + D_l5 = D_flat[:, :121].view(edge_quaternion.shape[0], 11, 11) + D_l6 = D_flat[:, 121:].view(edge_quaternion.shape[0], 13, 13) + return D_l5, D_l6 + + def _compute_l7_block(self, edge_quaternion: torch.Tensor) -> torch.Tensor: + """Compute the ``l=7`` block from the dedicated degree-14 monomial kernel.""" + powers = self._precompute_powers(edge_quaternion, 14) + monomials = self._build_monomial_matrix( + powers, + self.small_order_kernels.exp_l7, + ) + D_flat = torch.matmul( + monomials, + self.small_order_kernels.C_l7.transpose(0, 1), + ) + return D_flat.view(edge_quaternion.shape[0], 15, 15) + + def _compute_l7l8_blocks( + self, + edge_quaternion: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + """Compute the ``l=7`` and ``l=8`` blocks from one shared degree-16 kernel.""" + powers = self._precompute_powers(edge_quaternion, 16) + monomials = self._build_monomial_matrix( + powers, + self.small_order_kernels.exp_l8, + ) + D_flat = torch.matmul( + monomials, + self.small_order_kernels.C_combined_l7l8.transpose(0, 1), + ) + D_l7 = D_flat[:, :225].view(edge_quaternion.shape[0], 15, 15) + D_l8 = D_flat[:, 225:].view(edge_quaternion.shape[0], 17, 17) + return D_l7, D_l8 + + def _compute_l9_block(self, edge_quaternion: torch.Tensor) -> torch.Tensor: + """Compute the ``l=9`` block from the dedicated degree-18 monomial kernel.""" + powers = self._precompute_powers(edge_quaternion, 18) + monomials = self._build_monomial_matrix( + powers, + self.small_order_kernels.exp_l9, + ) + D_flat = torch.matmul( + monomials, + self.small_order_kernels.C_l9.transpose(0, 1), + ) + return D_flat.view(edge_quaternion.shape[0], 19, 19) + + def _compute_l9l10_blocks( + self, + edge_quaternion: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + """Compute the ``l=9`` and ``l=10`` blocks from one shared degree-20 kernel.""" + powers = self._precompute_powers(edge_quaternion, 20) + monomials = self._build_monomial_matrix( + powers, + self.small_order_kernels.exp_l10, + ) + D_flat = torch.matmul( + monomials, + self.small_order_kernels.C_combined_l9l10.transpose(0, 1), + ) + D_l9 = D_flat[:, :361].view(edge_quaternion.shape[0], 19, 19) + D_l10 = D_flat[:, 361:].view(edge_quaternion.shape[0], 21, 21) + return D_l9, D_l10 + @staticmethod def _factorial_table( n: int, dtype: torch.dtype, device: torch.device diff --git a/deepmd/pt/model/model/__init__.py b/deepmd/pt/model/model/__init__.py index 08a70bfd93..1a01b05fe9 100644 --- a/deepmd/pt/model/model/__init__.py +++ b/deepmd/pt/model/model/__init__.py @@ -307,6 +307,7 @@ def get_sezm_model(model_params: dict) -> BaseModel: model_params = copy.deepcopy(model_params) model_params.setdefault("descriptor", {}) model_params.setdefault("fitting_net", {}) + model_params["descriptor"].setdefault("type", "dpa4") ntypes = len(model_params["type_map"]) model_params["descriptor"]["ntypes"] = ntypes @@ -378,6 +379,7 @@ def get_sezm_spin_model(model_params: dict) -> BaseModel: model_params = copy.deepcopy(model_params) model_params.setdefault("descriptor", {}) model_params.setdefault("fitting_net", {}) + model_params["descriptor"].setdefault("type", "dpa4") _normalize_spin_use_spin(model_params) real_sel = model_params["descriptor"].get("sel", 120) real_sel_list = [int(real_sel)] if isinstance(real_sel, int) else list(real_sel) diff --git a/deepmd/pt/model/model/sezm_model.py b/deepmd/pt/model/model/sezm_model.py index c5aa7f70ca..44de7260a0 100644 --- a/deepmd/pt/model/model/sezm_model.py +++ b/deepmd/pt/model/model/sezm_model.py @@ -1,14 +1,23 @@ # SPDX-License-Identifier: LGPL-3.0-or-later """SeZM: Smooth equivariant Zone-bridging Model. -This module hosts the full ``torch.compile`` + ``make_fx`` pipeline that -runs the SeZM energy (``ener``) path on the GPU. To the authors' -knowledge this is the first public implementation of a compiled, -dynamically shaped machine-learning potential whose *second-order* -derivatives -- required by force-loss training -- travel end-to-end -through Inductor without any eager fallback. The ``dens`` path below -uses a plain ``torch.compile`` wrapper and is not covered by the rest of -this docstring. +This module hosts the ``make_fx`` + Inductor pipeline that runs the SeZM +energy (``ener``) path on the GPU. To the authors' knowledge this is the +first public implementation of a compiled, dynamically shaped +machine-learning potential whose *second-order* derivatives -- required by +force-loss training -- travel end-to-end through Inductor without any +eager fallback. + +After ``make_fx`` captures the graph the two modes diverge at the backend. +**Training** lowers it with ``torch.compile``: the Dynamo frontend builds +the optimizer's second backward through the already-materialised first +derivative. **Inference** lowers the same graph with +``aot_module_simplified`` -- AOTAutograd's forward-only path -- which skips +the Dynamo frontend entirely. Dynamo's shape-guard production aborts on +the forward-only graph (``sources must not be empty for symbol s...``), +and its activation handling costs ~3x the peak memory; the AOTAutograd +inference path avoids both. The ``dens`` path below uses a plain +``torch.compile`` wrapper and is not covered by the rest of this docstring. Why force-loss training is hard to compile ========================================== @@ -60,17 +69,17 @@ | | | tracing_mode="symbolic", | | | _allow_non_fake_inputs=True, | | | decomposition_table=) (NOTE 0) - | | | * trace inputs are nf=2 copies (NOTE 1) + | | | * trace inputs use safe prime dims (NOTE 1) | | | * silu_backward is decomposed (NOTE 2) | | | * traced graph already contains the | | | first autograd.grad over coords | | |-- _strip_saved_tensor_detach (train only) (NOTE 3) - | | |-- _rebuild_graph_module (NOTE 4) - | | '-- torch.compile(backend="inductor", - | | dynamic=True, - | | options=) (NOTE 6) + | | |-- _rebuild_graph_module (train only) (NOTE 4) + | | |-- train: torch.compile(backend="inductor", + | | | dynamic=True, options=) (NOTE 6) + | | '-- eval: aot_module_simplified (forward-only) (NOTE 13) | | stored in compiled_core_compute_cache[key] (NOTE 8) - | '-- compiled_core_compute_cache[key](...) + | '-- compiled_core_compute_cache[key](...) under no_grad in eval '-- communicate_extended_output Subsequent batches look up the cached callable at the same @@ -124,24 +133,22 @@ become symbolic immediately after the first op, so only the control flow is decided by concrete values. -NOTE 1 -- Tracing with ``nf=2`` -------------------------------- +NOTE 1 -- Trace inputs with safe prime dimensions +------------------------------------------------- ``make_fx(tracing_mode="symbolic")`` replaces tensor shapes with sympy -symbols at trace time, but the moment a symbolic dim ends up equal to a -concrete dim elsewhere in the same tensor it collapses into a constant. -Concretely: - -* ``nf=1`` triggers PyTorch's 0/1 specialization and bakes ``nf`` into - the graph. -* ``nf=3`` collides with the spatial ``3`` in ``extended_coord`` whose - shape is ``(nf, nall, 3)``. -* ``nf=9`` would collide with the virial dim. - -Any of those collisions forces ``torch.compile(dynamic=True)`` to reject -later batches whose ``nf`` differs from the traced constant. ``nf=2`` -is the smallest batch size free of every known collision; we always -repeat the first frame twice to satisfy this invariant during tracing. +symbols at trace time. If two input dimensions happen to share the same +concrete value, PyTorch may assign them one duck-shaped symbol and bake a +false equality such as ``nloc == nall`` into the graph. Dimensions that +match internal literals are also fragile: ``1`` triggers 0/1 +specialization, while ``2`` / ``3`` / ``9`` commonly appear as +charge-spin, Cartesian-coordinate, and virial widths. + +Before tracing, SeZM pads or trims real batch tensors so ``nf``, ``nloc`` +and ``nall`` become pairwise-distinct primes >= 5 that do not collide +with fixed model dimensions. ``nlist`` and ``mapping`` are clamped after +the shape coercion so their index values stay valid for the trace-only +sample. NOTE 2 -- Decomposing ``silu_backward`` --------------------------------------- @@ -219,8 +226,10 @@ NOTE 6 -- Inductor / Triton option lockdown ------------------------------------------- -``torch.compile(backend="inductor", dynamic=True, options=...)`` is -configured with: +One Inductor option set governs both backends: ``torch.compile`` takes it +as ``options=`` for training, and the eval path (NOTE 13) applies it via +``torch._inductor.config.patch`` around ``compile_fx_inner``, adding +``triton.max_tiles=1``. The options are: * ``max_autotune=False`` Autotune regresses on dynamic shapes because each recompile rolls @@ -287,10 +296,21 @@ encounter of each mode pays the compile cost once, and every later toggle is a dict lookup. +Multi-task runs add one module-level sharing layer on top of this +per-instance cache. Tasks whose descriptor and fitting parameters are +the same Python objects after ``share_params(level=0)`` reuse a single +compiled callable. Per-task tensors that must remain distinct +(``out_bias``, ``out_std``, ``bias_atom_e`` and ``case_embd``) are +promoted to explicit FX placeholders, so the shared graph reads their +current values at each call instead of baking the first task's tensors as +constants. + Enabling compile for eval is an opt-in via ``DP_COMPILE_INFER=1`` (``should_use_compile`` returns ``_env_use_compile_infer`` when ``self.training`` is ``False``). Once enabled, regular validation, full validation and EMA full validation all reuse the eval slot. +Eval TF32 is separately controlled by ``DP_TF32_INFER``: +``0 -> highest``, ``1 -> high``, ``2 -> medium``. NOTE 8 -- Storing the compile cache outside the ``nn.Module`` tree ------------------------------------------------------------------ @@ -320,8 +340,10 @@ ``dE/dx`` against a graph of known shape and ownership -- the essential precondition for make_fx symbolic tracing. -In eval mode we merely detach; no ``create_graph`` is requested, so the -compiled kernel never has to build a backward graph. +In eval the rebound coordinate still requires grad for the eager +(non-compiled) path's ``autograd.grad``, but the compiled callable runs +under ``torch.no_grad`` so its AOTAutograd inference lowering builds no +outer backward (NOTE 13). NOTE 10 -- Tail dummy edges --------------------------- @@ -359,6 +381,53 @@ the outer optimizer's ``.backward()`` can continue walking it into the parameters. When ``False`` the double-backward graph is never built, saving memory during inference. + +NOTE 13 -- Inference lowering through ``aot_module_simplified`` +--------------------------------------------------------------- + +Training lowers the traced graph with ``torch.compile`` so the Dynamo +frontend can build the optimizer's second backward. Inference needs no +such backward, and routing it through Dynamo is actively harmful: + +* Dynamo re-traces the already-symbolic ``make_fx`` graph and re-runs + shape-guard production. On the forward-only graph one intermediate + view's extended-atom (``nall``) axis becomes a backed symbol with no + input source, and ``produce_guards`` aborts with ``sources must not be + empty for symbol s...``. +* Even when it compiles, the grad-bearing input makes AOTAutograd treat + the call as forward+backward and keep the whole forward activation set + alive -- ~3x the eager peak memory, OOM-ing on large inference sweeps. + +So eval lowers the graph with ``aot_module_simplified`` -- AOTAutograd's +inference path -- with no Dynamo frontend. It still functionalizes the +graph (in-place ops become out-of-place, so Inductor reuses buffers and +peak memory matches eager), and compiling under ``torch.no_grad`` selects +the forward-only partition. The make_fx placeholders' fake values carry +the symbolic sizes, so the single lowered artifact stays dynamic across +``nframes`` / ``nall`` / edge count without recompiles. + +Four details make the hand-wired AOTAutograd + Inductor path work: + +* **Flat output.** AOTAutograd rejects the dict ``core_compute`` + returns, so the graph output is rewritten to a tuple of the dict + values and the dict is re-packed around the compiled callable. +* **Decomposition table.** ``select_decomp_table()`` is passed so the + decomposition set matches Inductor's fallback set; the default + core-aten table clashes (``both a fallback and a decomp for same op: + aten._to_copy.default``). +* **Compile-time default device.** ``aot_module_simplified`` enters a + ``PhiloxStateTracker`` that allocates an RNG-state tensor without an + explicit device, so the compile runs under ``torch.device(model + device)`` to keep it off any stray ambient default. +* **1D pointwise grids.** ``triton.max_tiles=1`` keeps the + data-dependent edge axis on Triton's ``x`` grid (limit ``2**31``); the + default tiling places it on the ``y``/``z`` grid (limit ``65535``), + which overflows past ~2e4 atoms with ``CUDA error: invalid argument``. + +The random local-Z roll (``random_gamma``) is gated to training in the +descriptor, so the inference graph holds no ``aten.rand`` at all: the +model is roll-equivariant, which makes the roll a pure training +augmentation and inference deterministic. """ from __future__ import ( @@ -370,6 +439,7 @@ import time from contextlib import ( contextmanager, + nullcontext, ) from typing import ( TYPE_CHECKING, @@ -418,6 +488,10 @@ from deepmd.pt.utils.nlist import ( extend_input_and_build_neighbor_list, ) +from deepmd.pt.utils.nv_nlist import ( + NvNeighborList, + is_nv_available, +) from deepmd.utils.version import ( check_version_compatibility, ) @@ -426,6 +500,10 @@ SeZMModel_ = make_model(SeZMAtomicModel) +# Local-atom count above which the O(N) Toolkit-Ops cell list replaces the dense +# all-pairs builder for periodic CUDA systems. +SEZM_NV_NLIST_THRESHOLD = 1024 + # NOTE: Silence Inductor / Triton autotune dumps before any submodule is # imported. ``torch.compile`` reads these environment variables exactly # once at backend initialisation; setting them after the first compile @@ -453,59 +531,113 @@ # Multi-task compile sharing # --------------------------------------------------------------------------- # Maps (structure_key..., training, do_atomic_virial, has_coord_corr) to the -# compiled callable. Tasks whose descriptor AND fitting-net first child have -# the same Python-object identity (after share_params) reuse a single compiled -# graph, avoiding Nx compile-cache OOM and N DDP graph boundaries (NCCL timeout). -_SEZM_COMPILE_CACHE: dict[tuple, Any] = {} +# compiled callable. Tasks whose descriptor and fitting parameters share the +# same Python-object identity after ``share_params(level=0)`` reuse one compiled +# graph, avoiding N x compile-cache growth and duplicated DDP graph boundaries. +_SEZM_COMPILE_CACHE: dict[tuple[Any, ...], Any] = {} # Maps structure_key -> task_buf_order so every instance in the same group # knows which buffers were promoted and in what order. -_SEZM_TASK_BUF_ORDER: dict[tuple[int, ...], tuple[str, ...]] = {} +_SEZM_TASK_BUF_ORDER: dict[tuple[Any, ...], tuple[str, ...]] = {} # Prefix namespace for promoted buffer names. -_AM_PREFIX = "am/" # atomic_model registered buffer -_FIT_PREFIX = "fit/" # fitting_net registered buffer -_FIT_ATTR_PREFIX = "fit_attr/" # fitting_net plain tensor attribute (not in _buffers) +_AM_PREFIX = "am/" +_FIT_PREFIX = "fit/" +_ENV_BOOL_CHOICES = { + "1": True, + "true": True, + "yes": True, + "on": True, + "0": False, + "false": False, + "no": False, + "off": False, +} +_TF32_INFER_PRECISION_CHOICES = { + "0": "highest", + "1": "high", + "2": "medium", +} -def _sezm_structure_key(model: SeZMModel) -> tuple[int, ...]: +def _module_shared_key(module: torch.nn.Module) -> tuple[int, ...]: + """Return the direct identities that define a shared compiled structure.""" + child_ids = tuple(id(child) for child in module.children()) + param_ids = tuple(id(param) for param in module.parameters(recurse=False)) + if child_ids or param_ids: + return child_ids + param_ids + return (id(module),) + + +def _int_tuple(values: Any) -> tuple[int, ...]: + """Return a stable integer tuple for graph-state keys.""" + return tuple(int(value) for value in values) + + +def _int_pair_tuple(values: Any) -> tuple[tuple[int, int], ...]: + """Return a stable integer-pair tuple for graph-state keys.""" + return tuple(sorted(tuple(int(type_id) for type_id in pair) for pair in values)) + + +def _sezm_structure_key(model: SeZMModel) -> tuple[Any, ...]: """Return a key that is equal iff two SeZMModel instances can share a compiled graph. After ``share_params``, the descriptor and fitting-net module objects themselves remain *different* Python objects per task; only their - *submodules* (``_modules`` dict entries) are replaced with shared - references. Using ``id(descriptor)`` or ``id(fitting_net)`` would - therefore always differ between tasks and defeat the cache. - - Fix: use the id of the *first named child* of each module. After - ``share_params(level=0)``, those children are the same Python objects - for all tasks in the same structure group, giving matching keys. - - NOTE: only the FIRST child is sampled, assuming "first child shared => - whole module shared" (true for level=0). Under ``share_params(level=1)`` - only ``type_embedding`` is shared; if it is the first child, two tasks - whose other descriptor weights differ would collapse to the same key and - wrongly reuse one compiled graph. If level=1 + compile is ever used, key - on all param ids instead, e.g. ``frozenset(id(p) for p in desc.parameters())``. + submodules / parameters are replaced with shared references. Using + ``id(descriptor)`` or ``id(fitting_net)`` would therefore always differ + between tasks and defeat the cache. + + The key uses direct child-module identities plus direct parameter + identities. This matches SeZM's ``share_params(level=0)`` implementation + and avoids false sharing when only part of the descriptor is linked. + Non-module state that changes ``core_compute`` branches or masks is + included explicitly. """ - try: - desc = model.atomic_model.descriptor - desc_id = 0 - for _, child in desc.named_children(): - desc_id = id(child) - break - if desc_id == 0: - # Descriptor has no named children (unlikely); fall back. - desc_id = id(desc) - except AttributeError: - desc_id = 0 - try: - fitting = model.atomic_model.fitting_net - for _, child in fitting.named_children(): - return (desc_id, id(child)) - return (desc_id, id(fitting)) - except AttributeError: - return (desc_id, id(model)) + atomic_model = model.atomic_model + descriptor = atomic_model.descriptor + fitting = atomic_model.fitting_net + descriptor_key = _module_shared_key(descriptor) + fitting_key = _module_shared_key(fitting) + descriptor_state = ( + _int_pair_tuple(descriptor.exclude_types), + bool(descriptor.use_triton), + bool(descriptor.use_env_seed), + bool(descriptor.use_gie), + bool(descriptor.random_gamma), + descriptor.charge_spin_embedding is not None, + descriptor.inner_clamp is not None, + descriptor.bridging_switch is not None, + descriptor.inner_clamp_r_inner, + descriptor.inner_clamp_r_outer, + int(descriptor.get_dim_chg_spin()), + ) + fitting_state = ( + _int_tuple(fitting.exclude_types), + bool(fitting.eval_return_middle_output), + ) + atomic_state = ( + _int_tuple(atomic_model.atom_exclude_types), + bool(atomic_model.enable_eval_descriptor_hook), + bool(atomic_model.enable_eval_fitting_last_layer_hook), + ) + model_state = ( + str(model.bridging_method), + model.inter_potential is not None, + float(model.bridging_r_inner), + float(model.bridging_r_outer), + tuple(model.get_type_map()), + ) + return ( + descriptor_key + + fitting_key + + ( + descriptor_state, + fitting_state, + atomic_state, + model_state, + ) + ) def _get_sezm_task_buf_names(model: SeZMModel) -> tuple[str, ...]: @@ -518,27 +650,17 @@ def _get_sezm_task_buf_names(model: SeZMModel) -> tuple[str, ...]: * ``bias_atom_e`` on the fitting net — task-specific per-type bias that differs across tasks after ``share_params``. * ``case_embd`` on the fitting net — task-identity vector used for - multi-task case conditioning; stored as a plain tensor attribute. + multi-task case conditioning. """ names: list[str] = [] - try: - am = model.atomic_model - for bname in ("out_bias", "out_std"): - if am._buffers.get(bname) is not None: - names.append(_AM_PREFIX + bname) - try: - fitting = am.fitting_net - for bname in ("bias_atom_e",): - if fitting._buffers.get(bname) is not None: - names.append(_FIT_PREFIX + bname) - for aname in ("case_embd",): - val = getattr(fitting, aname, None) - if val is not None and torch.is_tensor(val): - names.append(_FIT_ATTR_PREFIX + aname) - except AttributeError: - pass - except AttributeError: - pass + atomic_model = model.atomic_model + fitting = atomic_model.fitting_net + for bname in ("out_bias", "out_std"): + if atomic_model._buffers.get(bname) is not None: + names.append(_AM_PREFIX + bname) + for bname in ("bias_atom_e", "case_embd"): + if fitting._buffers.get(bname) is not None: + names.append(_FIT_PREFIX + bname) return tuple(names) @@ -549,54 +671,19 @@ def _get_sezm_task_buf_vals( """Return the current tensor values for the given promoted-buffer names.""" if not names: return () - am = model.atomic_model - try: - fitting = am.fitting_net - except AttributeError: - fitting = None + atomic_model = model.atomic_model + fitting = atomic_model.fitting_net vals: list[torch.Tensor] = [] for name in names: if name.startswith(_AM_PREFIX): - vals.append(am._buffers[name[len(_AM_PREFIX) :]]) + vals.append(atomic_model._buffers[name[len(_AM_PREFIX) :]]) elif name.startswith(_FIT_PREFIX): - vals.append(fitting._buffers[name[len(_FIT_PREFIX) :]]) # type: ignore[union-attr] - elif name.startswith(_FIT_ATTR_PREFIX): - vals.append(getattr(fitting, name[len(_FIT_ATTR_PREFIX) :])) + vals.append(fitting._buffers[name[len(_FIT_PREFIX) :]]) + else: + raise ValueError(f"Unknown SeZM task-buffer name: {name}") return tuple(vals) -def _parse_optional_env_bool(var_name: str) -> bool | None: - """ - Parse an optional boolean environment variable. - - Parameters - ---------- - var_name - Environment variable name. - - Returns - ------- - bool | None - Parsed boolean value, or ``None`` when the variable is unset. - - Raises - ------ - ValueError - If the environment variable value is not a supported boolean token. - """ - raw_value = os.environ.get(var_name) - if raw_value is None: - return None - normalized = raw_value.strip().lower() - if normalized in {"1", "true", "yes", "on"}: - return True - if normalized in {"0", "false", "no", "off"}: - return False - raise ValueError( - f"{var_name} must be one of 1/0/true/false/yes/no/on/off, got {raw_value!r}" - ) - - def _check_compile_torch_version() -> None: """Fail fast when SeZM compile is requested on unsupported PyTorch.""" version = Version(torch.__version__).release @@ -805,11 +892,28 @@ def __init__( # Maps cache_key -> task_buf_order for this instance so forward() # knows which buffers to pass and in what order. object.__setattr__(self, "_task_buf_order_cache", {}) - # Training follows `use_compile`. Evaluation/inference reads - # `DP_COMPILE_INFER` at init time and falls back to eager when unset. - self._env_use_compile_infer: bool | None = _parse_optional_env_bool( - "DP_COMPILE_INFER" - ) + + # Training follows `use_compile`. Evaluation/inference samples env + # policy at init time so path and precision stay fixed per model. + compile_infer_env = os.environ.get("DP_COMPILE_INFER") + if compile_infer_env is None: + self._env_use_compile_infer: bool | None = None + else: + compile_infer_env = compile_infer_env.strip().lower() + if compile_infer_env not in _ENV_BOOL_CHOICES: + choices = "/".join(_ENV_BOOL_CHOICES) + raise ValueError( + f"DP_COMPILE_INFER must be one of {choices}, " + f"got {compile_infer_env!r}" + ) + self._env_use_compile_infer = _ENV_BOOL_CHOICES[compile_infer_env] + + tf32_infer_env = os.environ.get("DP_TF32_INFER", "0").strip().lower() + if tf32_infer_env not in _TF32_INFER_PRECISION_CHOICES: + raise ValueError( + f"DP_TF32_INFER must be one of 0/1/2, got {tf32_infer_env!r}" + ) + self._tf32_infer_precision = _TF32_INFER_PRECISION_CHOICES[tf32_infer_env] if self.use_compile or self._env_use_compile_infer is True: _check_compile_torch_version() @@ -1092,16 +1196,16 @@ def forward_common_after_nlist( device=extended_coord.device, ) - if self.should_use_compile(): - fp, ap = self.convert_fp_ap( - fp, - ap, - nf=nf, - nloc=nloc, - dtype=extended_coord.dtype, - device=extended_coord.device, - ) - with self.tf32_precision_ctx(): + with self.tf32_precision_ctx(): + if self.should_use_compile(): + fp, ap = self.convert_fp_ap( + fp, + ap, + nf=nf, + nloc=nloc, + dtype=extended_coord.dtype, + device=extended_coord.device, + ) if self.compiled_dens_compute is None or not self._dens_compiled: self.compile_dens() with nvtx_range("SeZM/core_compute_dens"): @@ -1124,19 +1228,19 @@ def forward_common_after_nlist( time.perf_counter() - self._dens_pending_compile_t0, ) self._dens_pending_compile_t0 = None - else: - with nvtx_range("SeZM/core_compute_dens"): - compute_ret = self.core_compute_dens( - extended_coord, - extended_atype, - nlist, - mapping, - force_input=force_input, - noise_mask=noise_mask, - fparam=fp, - aparam=ap, - charge_spin=charge_spin, - ) + else: + with nvtx_range("SeZM/core_compute_dens"): + compute_ret = self.core_compute_dens( + extended_coord, + extended_atype, + nlist, + mapping, + force_input=force_input, + noise_mask=noise_mask, + fparam=fp, + aparam=ap, + charge_spin=charge_spin, + ) with nvtx_range("SeZM/post_process"): model_predict = self.post_process_output_dens( compute_ret, @@ -1162,16 +1266,16 @@ def forward_common_after_nlist( else: extended_coord = extended_coord.detach() - if self.should_use_compile(): - fp, ap = self.convert_fp_ap( - fp, - ap, - nf=nf, - nloc=nloc, - dtype=extended_coord.dtype, - device=extended_coord.device, - ) - with self.tf32_precision_ctx(): + with self.tf32_precision_ctx(): + if self.should_use_compile(): + fp, ap = self.convert_fp_ap( + fp, + ap, + nf=nf, + nloc=nloc, + dtype=extended_coord.dtype, + device=extended_coord.device, + ) has_coord_corr = extended_coord_corr is not None cache_key = ( bool(self.training), @@ -1197,9 +1301,16 @@ def forward_common_after_nlist( # each call rather than caching the values at compile time). _task_buf_vals = _get_sezm_task_buf_vals( self, - getattr(self, "_task_buf_order_cache", {}).get(cache_key, ()), + self._task_buf_order_cache[cache_key], ) - with nvtx_range("SeZM/core_compute"): + # NOTE: Inference needs no autograd tape -- the force + # (-dE/dx) is already materialised as forward ops in the + # traced graph, so keeping the tape would only make + # AOTAutograd save the full forward activation set for a + # backward eval never runs (see NOTE 13). Training keeps it + # for the force-loss second derivative. + grad_ctx: Any = nullcontext() if self.training else torch.no_grad() + with nvtx_range("SeZM/core_compute"), grad_ctx: if extended_coord_corr is None: model_predict_lower = compiled_core_compute( extended_coord, @@ -1240,20 +1351,20 @@ def forward_common_after_nlist( ) self._core_compute_pending_compile_t0 = None self._core_compute_pending_compile_key = None - else: - with nvtx_range("SeZM/core_compute"): - model_predict_lower = self.core_compute( - extended_coord, - extended_atype, - nlist, - mapping=mapping, - fparam=fp, - aparam=ap, - charge_spin=charge_spin, - do_atomic_virial=do_atomic_virial, - extra_nlist_sort=self.need_sorted_nlist_for_lower(), - extended_coord_corr=extended_coord_corr, - ) + else: + with nvtx_range("SeZM/core_compute"): + model_predict_lower = self.core_compute( + extended_coord, + extended_atype, + nlist, + mapping=mapping, + fparam=fp, + aparam=ap, + charge_spin=charge_spin, + do_atomic_virial=do_atomic_virial, + extra_nlist_sort=self.need_sorted_nlist_for_lower(), + extended_coord_corr=extended_coord_corr, + ) with nvtx_range("SeZM/communicate_output"): model_predict = communicate_extended_output( @@ -1728,9 +1839,7 @@ def trace_and_compile( self.compiled_core_compute_cache[cache_key] = _SEZM_COMPILE_CACHE[ full_cache_key ] - self._task_buf_order_cache[cache_key] = _SEZM_TASK_BUF_ORDER.get( - structure_key, () - ) + self._task_buf_order_cache[cache_key] = _SEZM_TASK_BUF_ORDER[structure_key] log.info( "SeZM: reusing shared compiled graph " "(mode=%s, atomic_virial=%s, coord_corr=%s)", @@ -1758,15 +1867,12 @@ def trace_and_compile( # Resolve module references once for the buffer-patching closures. _am_patch = self.atomic_model - try: - _fitting_patch: torch.nn.Module | None = _am_patch.fitting_net - except AttributeError: - _fitting_patch = None + _fitting_patch = _am_patch.fitting_net def _patch_task_bufs( vals: tuple[torch.Tensor, ...], - ) -> dict[str, torch.Tensor | None]: - """Temporarily replace model buffers/attrs with FX proxy tensors. + ) -> dict[str, torch.Tensor]: + """Temporarily replace task-local buffers with FX proxy tensors. Executed at trace time inside compute_fn. make_fx records the proxy tensors as placeholder nodes, so the compiled graph reads them @@ -1774,44 +1880,38 @@ def _patch_task_bufs( block in compute_fn always calls ``_restore_task_bufs`` to leave the model in its original state after tracing. """ - saved: dict[str, torch.Tensor | None] = {} - for name, val in zip(task_buf_names, vals): - if name.startswith(_AM_PREFIX): - actual = name[len(_AM_PREFIX) :] - saved[name] = _am_patch._buffers.get(actual) - _am_patch._buffers[actual] = val - elif name.startswith(_FIT_PREFIX): - actual = name[len(_FIT_PREFIX) :] - saved[name] = ( - _fitting_patch._buffers.get(actual) - if _fitting_patch is not None - else None - ) - if _fitting_patch is not None: + if len(vals) != len(task_buf_names): + raise ValueError( + "SeZM task-buffer placeholder count mismatch: " + f"expected {len(task_buf_names)}, got {len(vals)}" + ) + saved: dict[str, torch.Tensor] = {} + try: + for name, val in zip(task_buf_names, vals): + if name.startswith(_AM_PREFIX): + actual = name[len(_AM_PREFIX) :] + saved[name] = _am_patch._buffers[actual] + _am_patch._buffers[actual] = val + elif name.startswith(_FIT_PREFIX): + actual = name[len(_FIT_PREFIX) :] + saved[name] = _fitting_patch._buffers[actual] _fitting_patch._buffers[actual] = val - elif name.startswith(_FIT_ATTR_PREFIX): - actual = name[len(_FIT_ATTR_PREFIX) :] - saved[name] = getattr(_fitting_patch, actual, None) - if _fitting_patch is not None: - setattr(_fitting_patch, actual, val) + except Exception: + _restore_task_bufs(saved) + raise return saved def _restore_task_bufs( - saved: dict[str, torch.Tensor | None], + saved: dict[str, torch.Tensor], ) -> None: - """Restore original model buffers/attrs after tracing.""" + """Restore original task-local buffers after tracing.""" for name, orig in saved.items(): if name.startswith(_AM_PREFIX): actual = name[len(_AM_PREFIX) :] _am_patch._buffers[actual] = orig elif name.startswith(_FIT_PREFIX): actual = name[len(_FIT_PREFIX) :] - if _fitting_patch is not None: - _fitting_patch._buffers[actual] = orig - elif name.startswith(_FIT_ATTR_PREFIX): - actual = name[len(_FIT_ATTR_PREFIX) :] - if _fitting_patch is not None: - setattr(_fitting_patch, actual, orig) + _fitting_patch._buffers[actual] = orig need_coord_grad = self.do_grad_r() or self.do_grad_c() @@ -1836,9 +1936,9 @@ def _prepare_coord_for_trace(coord: torch.Tensor) -> torch.Tensor: # make_fx treats each element as a separate placeholder so the compiled # graph reads them as live inputs every call — not baked-in constants. # The buffer-patching trick: at trace time the proxy tensors are written - # into _buffers / __dict__ so that downstream code (apply_out_stat, - # fitting_net.forward) reads the proxies and the ops are recorded in the - # FX graph. The finally block restores original state unconditionally. + # into _buffers so downstream code (apply_out_stat, fitting_net.forward) + # reads the proxies and the ops are recorded in the FX graph. The + # finally block restores original state unconditionally. if extended_coord_corr is None: def compute_fn( @@ -2030,32 +2130,39 @@ def compute_fn( # type: ignore[misc] decomposition_table=decomp_table, )(*trace_args) - # NOTE: Only strip autograd-inserted detach chains in training - # mode. With ``create_graph=True`` make_fx wraps every saved - # forward activation in a - # ``fwd_op -> detach_A -> detach_B -> bwd_op`` chain. Those - # detaches are informational in eager autograd but become real - # ops after tracing and sever the gradient path from the force - # loss back to theta -- training would silently emit zero - # parameter updates for the second-derivative term. In eval - # mode ``create_graph=False`` so the chain is never inserted - # and stripping would be wrong. if self.training: + # NOTE: Training is the only mode that needs FX graph repair. + # ``fit_output_to_model_output(create_graph=True)`` asks autograd to + # keep the force graph differentiable with respect to model + # parameters. During ``make_fx`` tracing, autograd represents saved + # forward activations through double-detach chains such as + # + # fwd_op -> detach_A -> detach_B -> bwd_op + # + # These detaches are bookkeeping in eager autograd, but ordinary FX + # operators after tracing. If left in place, they cut the + # second-derivative path from force loss back to theta and training + # silently produces zero updates for that term. Therefore the + # training graph first removes only the autograd-inserted detach + # chains, preserving user-explicit detach nodes by graph topology. _strip_saved_tensor_detach(traced) - # NOTE: Rebuild the FX graph from scratch. - # ``Graph.erase_node`` inside ``_strip_saved_tensor_detach`` - # unlinks nodes from the doubly linked list but on some PyTorch - # builds (observed on 2.11+cu130) leaves stale C-level - # prev/next pointers on neighbouring Node objects. Dynamo later - # re-traces the ``GraphModule`` and walks ``graph.nodes`` inside - # ``output_graph.py:_create_proxy`` to read ``nd.meta``; - # dereferencing one of those stale pointers segfaults the - # process. A single ``node_copy`` pass into a freshly allocated - # ``torch.fx.Graph`` builds an equivalent graph with a clean - # linked list. We always rebuild -- even in eval -- because a - # fresh graph is cheap and a segfault is fatal. - traced = _rebuild_graph_module(traced) + # ``_strip_saved_tensor_detach`` mutates ``traced.graph`` via + # ``Graph.erase_node``. On some PyTorch builds (observed on + # 2.11+cu130), node erasure may leave stale C-level prev/next + # pointers on neighbouring FX nodes; Dynamo can later dereference + # those stale links while re-tracing the GraphModule and segfault. + # Rebuilding copies the graph into a fresh linked list after all + # training-only erasures are complete. + # + # Eval/inference must not take this repair path. In eval, + # ``create_graph=False`` means autograd does not insert the + # double-detach chains, so no nodes are erased. The eval graph also + # contains data-dependent ``nonzero`` output sizes from sparse edge + # compaction; copying that graph can make the resulting unbacked + # symbols fail Dynamo's shape-guard generation. Keeping the original + # eval GraphModule preserves the traced metadata that Inductor needs. + traced = _rebuild_graph_module(traced) # NOTE: Conservative Inductor options keep SeZM's dynamic edge # graph from forming overly large Triton reduction kernels @@ -2106,12 +2213,90 @@ def compute_fn( # type: ignore[misc] # disables every Inductor/Triton feature that has ever # interacted badly with ``make_fx`` + double backward in # this project. - compiled = torch.compile( - traced, - backend="inductor", - dynamic=True, - options=compile_options, - ) + if self.training: + compiled = torch.compile( + traced, + backend="inductor", + dynamic=True, + options=compile_options, + ) + else: + # NOTE 13: Eval lowers the symbolic make_fx graph through + # AOTAutograd's inference path rather than torch.compile. Re-tracing + # with Dynamo re-runs shape-guard production, which on the + # forward-only graph leaves an intermediate view's extended-atom + # (nall) axis without an input source and aborts ("sources must not + # be empty for symbol s..."). AOTAutograd skips Dynamo but still + # functionalizes the graph, letting Inductor reuse buffers (a plain + # torch._inductor.compile does not, at ~3x peak memory). no_grad + # selects the forward-only partition; the fake placeholder values + # keep the lowering dynamic. + from torch._functorch.aot_autograd import ( + aot_module_simplified, + ) + from torch._inductor import config as _ind_cfg + from torch._inductor.compile_fx import ( + compile_fx_inner, + ) + from torch._inductor.decomposition import ( + select_decomp_table, + ) + + example_inputs = [ + node.meta["val"] + for node in traced.graph.nodes + if node.op == "placeholder" + ] + + # AOTAutograd's flat-output contract rejects the dict that + # ``core_compute`` returns. Rewrite the graph's output to a tuple + # of the dict values, remember the keys, and re-pack the dict around + # the compiled callable. Insertion order makes keys and values line + # up. + _output_node = next( + node for node in reversed(traced.graph.nodes) if node.op == "output" + ) + _out_struct = _output_node.args[0] + if isinstance(_out_struct, dict): + _out_keys: list[str] | None = list(_out_struct.keys()) + _output_node.args = (tuple(_out_struct.values()),) + traced.recompile() + else: + _out_keys = None + + def _inductor_inference_compiler( + fx_gm: torch.fx.GraphModule, fx_inputs: list[Any] + ) -> Any: + # max_tiles=1 keeps pointwise grids 1D so the data-dependent + # edge axis stays on Triton's x grid (limit 2**31); the default + # tiling places it on the y/z grid (limit 65535), which + # overflows for large systems. + with _ind_cfg.patch({**compile_options, "triton.max_tiles": 1}): + return compile_fx_inner(fx_gm, fx_inputs) + + # select_decomp_table keeps the decomposition set aligned with + # Inductor's fallback set (a mismatch raises an aten._to_copy + # decomp/fallback clash). The torch.device context pins the model's + # device: AOTAutograd's PhiloxStateTracker allocates an RNG-state + # tensor without an explicit device, which otherwise lands on a + # stray default device and raises "invalid device ordinal". + with torch.no_grad(), torch.device(extended_coord.device): + _compiled_flat = aot_module_simplified( + traced, + example_inputs, + fw_compiler=_inductor_inference_compiler, + inference_compiler=_inductor_inference_compiler, + decompositions=select_decomp_table(), + ) + + if _out_keys is None: + compiled = _compiled_flat + else: + _keys = _out_keys + + def compiled(*args: Any, _fn: Any = _compiled_flat) -> dict[str, Any]: + return dict(zip(_keys, _fn(*args))) + # Populate both per-instance and module-level shared caches. # The shared cache (_SEZM_COMPILE_CACHE) lets a second task with the # same structure key skip re-tracing and re-compiling entirely. @@ -2310,6 +2495,17 @@ def fn( # Neighbor List Construction # ========================================================================= + def use_self_built_nlist(self) -> bool: + """Whether inference should keep this model's own neighbor-list path. + + SeZM builds an O(N) Toolkit-Ops neighbor list in + :meth:`build_neighbor_list` and runs the compiled ``forward`` path, so + the inference driver must not substitute an external ``NeighborList`` + strategy -- that would route through the eager lower interface and bypass + the compiled graph. + """ + return True + def build_neighbor_list( self, coord: Float[Tensor, "nf nloc 3"] | Float[Tensor, "nf nloc_x3"], @@ -2322,7 +2518,14 @@ def build_neighbor_list( Int[Tensor, "nf nloc nsel"], ]: """ - Build extended inputs and neighbor list (traditional path). + Build extended inputs and the neighbor list for the ``forward`` entry. + + Used when the model constructs its own neighbor list from ``coord`` / + ``box``, as opposed to ``forward_lower`` which receives an externally + built nlist (e.g. from LAMMPS or an inference ``NeighborList`` strategy). + Large periodic CUDA systems use the O(N) Toolkit-Ops cell list + (:class:`NvNeighborList`); all other cases use the dense all-pairs + builder. Either way the neighbor list is trimmed to ``sum(sel)``. Parameters ---------- @@ -2336,8 +2539,28 @@ def build_neighbor_list( Returns ------- tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor] - Extended coordinates, extended atom types, neighbor list, and mapping. + Extended coordinates, extended atom types, mapping, and neighbor list. """ + nloc = atype.shape[1] + if ( + box is not None + and coord.is_cuda + and nloc >= SEZM_NV_NLIST_THRESHOLD + and is_nv_available() + ): + # Large periodic systems: the device-resident O(N) Toolkit-Ops cell + # list avoids the dense all-pairs ghost expansion. It already keeps + # the nearest sum(sel) neighbors (fixed width, like the standard + # builder); only its (nlist, mapping) order is swapped to this + # method's (mapping, nlist) contract. + extended_coord, extended_atype, nlist, mapping = NvNeighborList().build( + coord.view(atype.shape[0], nloc, 3), + atype, + box, + self.get_rcut(), + self.get_sel(), + ) + return extended_coord, extended_atype, mapping, nlist return extend_input_and_build_neighbor_list( coord, atype, @@ -2420,18 +2643,33 @@ def build_edge_list_from_nlist( neighbor_safe = torch.where( valid_flat, neighbor_flat, torch.zeros_like(neighbor_flat) ) - coord_flat = coord_for_diff.flatten(0, 1) - dst_ext = f_idx * nall + dst_local - src_ext = f_idx * nall + neighbor_safe.to(dtype=torch.long) - diff = coord_flat.index_select(0, src_ext) - coord_flat.index_select(0, dst_ext) + # Gather coordinates within each frame instead of flattening + # ``(nf, nall)`` into one index space. The flattened form is + # mathematically correct, but Inductor may lower + # ``coord_flat.index_select(0, f_idx * nall + local_idx)`` to a kernel + # that asserts the composite index against ``nall`` rather than + # ``nf * nall`` when ``nf > 1``. Frame-local gather keeps the same + # differentiable path to coordinates while making every indirect index + # visibly bounded by the atom axis. + neighbor_safe_2d = neighbor_safe.to(dtype=torch.long).view(nf, nloc * nsel) + nei_coord = torch.gather( + coord_for_diff, + 1, + neighbor_safe_2d.unsqueeze(-1).expand(-1, -1, 3), + ).reshape(-1, 3) + dst_coord = torch.gather( + coord_for_diff[:, :nloc, :], + 1, + dst_local.view(nf, -1).unsqueeze(-1).expand(-1, -1, 3), + ).reshape(-1, 3) + diff = nei_coord - dst_coord edge_len2 = torch.sum(diff * diff, dim=-1) # === Step 2. Build compact src/dst (local indices) === if mapping is None: src_local = neighbor_safe.to(dtype=torch.long) else: - mapping_flat = mapping.reshape(-1) - src_local = mapping_flat.index_select(0, f_idx * nall + neighbor_safe) + src_local = torch.gather(mapping, 1, neighbor_safe_2d).reshape(-1) src_actual = f_idx * nloc + src_local.to(dtype=torch.long) # Filter: valid nlist entry AND src in [0, nloc) AND non-zero distance. @@ -2876,20 +3114,20 @@ def deserialize(cls, data: dict[str, Any]) -> SeZMModel: def tf32_precision_ctx(self) -> Generator[None, None, None]: """Context manager to temporarily set TF32 matmul precision. - TF32 is only enabled when the model is in training mode; during - inference we force ``highest`` precision because the reduced - mantissa of TF32 can introduce unacceptable errors in force - predictions and downstream MD trajectories. + Training follows ``enable_tf32``. Eval/inference follows + ``DP_TF32_INFER``: 0 keeps ``highest`` precision, 1 selects + ``high``, and 2 selects ``medium``. """ - if not self.should_use_compile() or not torch.cuda.is_available(): + if not torch.cuda.is_available(): yield return prev_precision = torch.get_float32_matmul_precision() try: - if self.enable_tf32 and self.training: - torch.set_float32_matmul_precision("high") + if self.training: + precision = "high" if self.enable_tf32 else "highest" else: - torch.set_float32_matmul_precision("highest") + precision = self._tf32_infer_precision + torch.set_float32_matmul_precision(precision) yield finally: torch.set_float32_matmul_precision(prev_precision) diff --git a/deepmd/pt/train/training.py b/deepmd/pt/train/training.py index 49b87138a7..97ae63589a 100644 --- a/deepmd/pt/train/training.py +++ b/deepmd/pt/train/training.py @@ -181,18 +181,15 @@ def __init__( training_params = config["training"] optimizer_params = config.get("optimizer", {}) - # NOTE: Translate ``validating.compiled_infer`` (input.json opt-in) - # into the ``DP_COMPILE_INFER`` environment variable *before* any - # model is constructed below. SeZMModel samples this env var - # exactly once inside its __init__ (see ``_env_use_compile_infer`` - # in ``deepmd/pt/model/model/sezm_model.py``) and uses the cached - # value to decide whether eval / full-validation forwards take - # the compile path. Setting it later would be silently ignored - # for the rest of the run. ``setdefault`` preserves any explicit - # shell-level override so a user who manually exported - # ``DP_COMPILE_INFER`` (either direction) stays in control. - if bool((config.get("validating") or {}).get("compiled_infer", False)): + validating_params = config.get("validating") or {} + # NOTE: Translate eval/inference options from input.json into + # environment variables before any model is constructed below. + # SeZMModel samples these env vars exactly once inside its __init__. + # ``setdefault`` preserves explicit shell-level overrides. + if bool(validating_params.get("compiled_infer", False)): os.environ.setdefault("DP_COMPILE_INFER", "1") + if bool(validating_params.get("tf32_infer", False)): + os.environ.setdefault("DP_TF32_INFER", "1") self.multi_task = "model_dict" in model_params self.finetune_links = finetune_links self.finetune_update_stat = False @@ -1093,7 +1090,6 @@ def single_model_finetune( self.full_validator = None self.ema_full_validator = None - validating_params = config.get("validating") or {} self.full_validator = self._create_full_validator( validating_params=validating_params, validation_data=validation_data, diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py index 762347fde7..34cdf08ad1 100644 --- a/deepmd/utils/argcheck.py +++ b/deepmd/utils/argcheck.py @@ -251,6 +251,21 @@ def get_all_argument(self, exclude_hybrid: bool = False) -> list[Argument]: raise ValueError(f"Invalid return type {type(args)}") return arguments + def get_argument(self, name: str) -> Argument: + """Get one registered argument by its canonical tag.""" + for (arg_name, alias, doc), metd in self.__plugin.plugins.items(): + if arg_name != name: + continue + args = metd() + if isinstance(args, Argument): + return args + if isinstance(args, list): + return Argument( + name=arg_name, dtype=dict, sub_fields=args, alias=alias, doc=doc + ) + raise ValueError(f"Invalid return type {type(args)}") + raise KeyError(f"Unknown argument plugin: {name}") + descrpt_args_plugin = ArgsPlugin() @@ -365,12 +380,12 @@ def descrpt_se_zm_args() -> list[Argument]: doc_channels = "Total channels per (l,m) coefficient." doc_basis_type = "Radial basis type. Supported values are `bessel` and `gaussian`." doc_n_radial = "Number of radial basis functions." - doc_radial_mlp = "Hidden layer sizes for radial networks. An output layer of size (l_schedule[0]+1)*channels will be automatically appended. Use 0 as a placeholder to be replaced by channels." + doc_radial_mlp = "Hidden layer sizes for radial networks. An output layer of size (l_schedule[0]+extra_node_l+1)*channels will be automatically appended. Use 0 as a placeholder to be replaced by channels." doc_use_env_seed = ( "If True, seed the initial node state with local-environment information: " "apply environment matrix FiLM conditioning on l=0 features using 4D " "[s, s*r_hat] representation, and enable the non-scalar geometric initial " - "embedding when l_schedule[0] > 0. If False, the initial state contains " + "embedding when l_schedule[0] + extra_node_l > 0. If False, the initial state contains " "only atom-local scalar features before message passing. Internal dimensions " "are derived from channels: embed_dim=min(channels, 128), " "axis_dim=min(4 if embed_dim < 64 else 8, embed_dim-1), " @@ -391,6 +406,12 @@ def descrpt_se_zm_args() -> list[Argument]: "`l_schedule` and satisfy `m_schedule[i] <= l_schedule[i]`. " "If set, `mmax` will be ignored." ) + doc_extra_node_l = ( + "Extra node representation degree above each message-passing degree. " + "`0` keeps the node representation identical to `l_schedule`. In general, " + "block `i` uses node degree `l_schedule[i] + extra_node_l`, while SO(2) " + "message passing still uses `l_schedule[i]`." + ) doc_n_blocks = "Number of blocks (only used when `l_schedule` is None)." doc_block_attn_res = ( "Descriptor-level block attention residual mode over block history " @@ -498,23 +519,70 @@ def descrpt_se_zm_args() -> list[Argument]: '`activation_function="silu"`. ' "`ffn_enabled=true` makes the block-internal FFN path use " '`activation_function="silu"` and `glu_activation=true`. ' - "S2-grid resolutions are resolved automatically per block. The e3nn " - "SO(2) grid is `[2 * mmax + 4, ceil_even(3 * lmax + 2)]`, and the " - "e3nn FFN grid is lifted to `[max(R_phi, R_theta), max(R_phi, R_theta)]`. " + "S2-grid resolutions are resolved automatically per block. The tensor-product " + "SO(2) grid uses the message-passing lmax as `[2 * mmax + 4, ceil_even(3 * lmax + 2)]`, " + "and the tensor-product FFN grid is lifted from the node lmax to `[max(R_phi, R_theta), max(R_phi, R_theta)]`. " "Lebedev branches use the smallest packaged rule with precision at " "least `3 * lmax`. " "The final scalar output FFN is unchanged." ) + doc_ffn_so3_grid = ( + "If True, use the Wigner-D SO(3) grid in the block-internal FFN. " + "This option takes precedence over the FFN grid path and ignores " + "`s2_activation[1]`; the SO(2) branch still follows `s2_activation[0]`." + ) + doc_node_wise_s2 = ( + "If True, enable an edge-local S2 pointwise product branch between " + "source and destination node features inside the SO(2) convolution." + ) + doc_node_wise_so3 = ( + "If True, enable the corresponding edge-local SO(3) Wigner-D grid-net " + "branch. It uses the source side as query and the destination side as " + "context. When enabled together with `node_wise_s2`, the SO(3) branch " + "is used for this path." + ) + doc_message_node_s2 = ( + "If True, enable a post-aggregation S2 pointwise product branch between " + "hidden messages and destination node features inside the SO(2) convolution." + ) + doc_message_node_so3 = ( + "If True, enable the corresponding post-aggregation SO(3) Wigner-D " + "grid-net branch. The message is used as query and the node state as " + "context. When enabled together with `message_node_s2`, the SO(3) " + "branch is used for this path." + ) doc_lebedev_quadrature = ( "Either one boolean applied to both S2 branches, or two booleans " "`[so2_enabled, ffn_enabled]` aligned with `s2_activation`. If a branch " "is enabled here, its S2 projector uses packaged Lebedev quadrature " - "rules instead of the e3nn product grid. The default keeps the existing " - "e3nn behavior." - ) - doc_grid_ffn = ( - "If True, use the optional grid-MLP structure for the block-internal " - "equivariant FFN. This does not change the final `l=0` output head." + "rules instead of the tensor-product sphere grid. The default enables " + "Lebedev quadrature for both S2 branches." + ) + doc_grid_mlp = ( + "Either one boolean applied to every grid path, or three booleans " + "`[node_wise, message_node, ffn]` selecting the polynomial point-wise " + "grid MLP operation per grid path. The grid MLP projects the two grid " + "fields, multiplies them point-wise, and projects the result back to " + "grid channels. On any path whose `grid_branch` entry is positive it is " + "overridden by branch mixing, and it has no effect on the final `l=0` " + "output head." + ) + doc_grid_branch = ( + "Either one non-negative integer applied to every grid path, or three " + "integers `[node_wise, message_node, ffn]` setting the number of " + "scalar-routed polynomial product branches per grid path. `0` disables " + "branch mixing on that path; positive values select branch mixing and " + "take precedence over `grid_mlp`. Branch weights are computed from " + "`l=0` scalar features only, while each branch is a quadratic product " + "of two channel-mixed grid fields. The `node_wise` and `message_node` " + "entries control the SO(2) convolution cross-grid paths, and the `ffn` " + "entry controls the block-internal FFN grid path." + ) + doc_kmax = ( + "Maximum Wigner-D frame order used by SO(3) grid nets. The frame set is " + "`[0, -1, 1, ..., -kmax, kmax]`. `kmax=1` is the default low-cost " + "setting that opens odd/antisymmetric coupling paths. The gamma grid is " + "resolved internally from `kmax`." ) doc_activation_function = ( f"Base activation function for helper MLPs, the SO(2) gated activation " @@ -606,9 +674,27 @@ def descrpt_se_zm_args() -> list[Argument]: default=1, doc=doc_mmax, ), + Argument( + "kmax", + int, + optional=True, + default=1, + extra_check=lambda x: x >= 0, + extra_check_errmsg="must be >= 0", + doc=doc_only_pt_supported + doc_kmax, + ), Argument( "m_schedule", list[int], optional=True, default=None, doc=doc_m_schedule ), + Argument( + "extra_node_l", + int, + optional=True, + default=0, + extra_check=lambda x: x >= 0, + extra_check_errmsg="must be >= 0", + doc=doc_extra_node_l, + ), Argument("n_blocks", int, optional=True, default=3, doc=doc_n_blocks), Argument("so2_norm", bool, optional=True, default=False, doc=doc_so2_norm), Argument("so2_layers", int, optional=True, default=4, doc=doc_so2_layers), @@ -682,10 +768,28 @@ def descrpt_se_zm_args() -> list[Argument]: ), Argument( "grid_mlp", - bool, + [bool, list[bool]], optional=True, default=False, - doc=doc_only_pt_supported + doc_grid_ffn, + extra_check=lambda x: isinstance(x, bool) or len(x) == 3, + extra_check_errmsg="must be a boolean or a list of three booleans: [node_wise, message_node, ffn]", + doc=doc_only_pt_supported + doc_grid_mlp, + ), + Argument( + "grid_branch", + [int, list[int]], + optional=True, + default=0, + extra_check=lambda x: ( + (isinstance(x, int) and x >= 0) + or ( + isinstance(x, list) + and len(x) == 3 + and all(isinstance(i, int) and i >= 0 for i in x) + ) + ), + extra_check_errmsg="must be a non-negative int or a list of three non-negative ints: [node_wise, message_node, ffn]", + doc=doc_only_pt_supported + doc_grid_branch, ), Argument( "ffn_blocks", @@ -742,6 +846,41 @@ def descrpt_se_zm_args() -> list[Argument]: extra_check_errmsg="must be a list of two booleans: [so2_activation, ffn_activation]", doc=doc_only_pt_supported + doc_s2_activation, ), + Argument( + "ffn_so3_grid", + bool, + optional=True, + default=False, + doc=doc_only_pt_supported + doc_ffn_so3_grid, + ), + Argument( + "node_wise_s2", + bool, + optional=True, + default=False, + doc=doc_only_pt_supported + doc_node_wise_s2, + ), + Argument( + "node_wise_so3", + bool, + optional=True, + default=False, + doc=doc_only_pt_supported + doc_node_wise_so3, + ), + Argument( + "message_node_s2", + bool, + optional=True, + default=False, + doc=doc_only_pt_supported + doc_message_node_s2, + ), + Argument( + "message_node_so3", + bool, + optional=True, + default=False, + doc=doc_only_pt_supported + doc_message_node_so3, + ), Argument( "lebedev_quadrature", [bool, list[bool]], @@ -3080,14 +3219,30 @@ def sezm_model_args() -> Argument: "descriptor", dict, [], - [descrpt_variant_type_args()], + [ + Variant( + "type", + [descrpt_args_plugin.get_argument("dpa4")], + optional=True, + default_tag="dpa4", + doc="The type of the descriptor.", + ) + ], doc=doc_only_pt_supported + doc_descrpt, ), Argument( "fitting_net", dict, [], - [fitting_variant_type_args()], + [ + Variant( + "type", + [fitting_args_plugin.get_argument("dpa4_ener")], + optional=True, + default_tag="dpa4_ener", + doc="The type of the fitting.", + ) + ], doc=doc_only_pt_supported + doc_fitting, ), Argument( @@ -5193,6 +5348,14 @@ def validating_args() -> Argument: "meaningful when `model.use_compile=true`; has no effect on models " "that do not implement the SeZM-style eval compile path." ) + doc_tf32_infer = ( + "Whether to enable TF32 `high` matmul precision for eval-time forwards " + "(including regular validation and full validation). When `true`, this " + "flag is translated into `DP_TF32_INFER=1` at trainer startup before any " + "model is constructed. A manually exported `DP_TF32_INFER` takes " + "precedence over this option. This does not affect training forwards, " + "which are controlled by `model.enable_tf32`." + ) args = [ Argument( "full_validation", @@ -5268,6 +5431,13 @@ def validating_args() -> Argument: default=False, doc=doc_only_pt_supported + doc_compiled_infer, ), + Argument( + "tf32_infer", + bool, + optional=True, + default=False, + doc=doc_only_pt_supported + doc_tf32_infer, + ), ] return Argument( "validating", diff --git a/doc/model/dpa4.md b/doc/model/dpa4.md index daff2f42fd..f6a6638ef2 100644 --- a/doc/model/dpa4.md +++ b/doc/model/dpa4.md @@ -4,36 +4,79 @@ **Supported backends**: PyTorch {{ pytorch_icon }} ::: -DPA4/SeZM is the DPA-series implementation of SeZM, the Smooth Equivariant -Zone-bridging Model. The recommended input type is `dpa4`; `DPA4`, `SeZM`, -and `sezm` are accepted aliases. For new input files, set -`model.type: "dpa4"` and `descriptor.type: "dpa4"`. +DPA4 is the DeePMD-kit implementation of the SeZM (Smooth Equivariant +Zone-bridging Model) architecture. Use `model.type: "dpa4"` in new input +files. The aliases `DPA4`, `SeZM`, and `sezm` are accepted for the same +implementation. The DPA4 model scaffold uses the SeZM descriptor and the +`dpa4_ener` fitting network, so `descriptor.type` and `fitting_net.type` +may be omitted in ordinary DPA4 inputs. + +Reference: [DPA4 paper](https://arxiv.org/abs/2606.02419). Training example: `examples/water/dpa4/input.json`. +Quick start: + +```bash +cd examples/water/dpa4 +dp --pt train input.json +``` + ## Overview DPA4/SeZM is an SO(3)-equivariant message-passing model for conservative -interatomic potentials. It predicts atomic energies and obtains forces -and virials by differentiating the energy, following the same -conservative formulation used by standard DeePMD energy models: +interatomic potentials. It predicts atomic energies and obtains forces and +virials by differentiating the energy, following the same conservative +formulation used by standard DeePMD energy models: ```math \mathbf{F}_i = -\frac{\partial E}{\partial \mathbf{r}_i}. ``` -The model retains vector and higher-order angular information during -descriptor construction. Only the final descriptor passed to the fitting -network is scalar. This separates the geometric representation from the -energy mapping: equivariant layers encode local geometry, and the fitting -network maps the resulting scalar features to atomic energies. +The model keeps vector and higher-order angular information while building +the descriptor. Only the final descriptor sent to the fitting network is +scalar. This separates geometric representation from energy prediction: +equivariant layers encode local environments, and the fitting network maps +the resulting scalar features to atomic energies. + +## Model scaffold + +The DPA4 model type is a convenience scaffold around the SeZM descriptor and +the `dpa4_ener` energy fitting network. A minimal input therefore only needs +the model type, `type_map`, and descriptor settings such as `sel` and `rcut`: + +```json +{ + "model": { + "type": "dpa4", + "type_map": [ + "O", + "H" + ], + "descriptor": { + "sel": 120, + "rcut": 6.0 + } + } +} +``` + +Options that are not written in the input use their documented defaults. +The neighbor selection `sel` may be an integer total neighbor limit, a +per-type list, or `auto` / `auto:factor`. + +Internally, the PyTorch model builds a standard DeePMD neighbor list for the +public forward path. When `use_compile` is enabled, the model additionally +uses a compact sparse-edge path for compiled training. Both paths share the +same descriptor and fitting definitions. ## Descriptor construction For each frame, DPA4/SeZM first builds a local neighbor graph within cutoff -radius `rcut`. Each edge stores the displacement vector, a smooth cutoff -weight, radial basis features, and a rotation from the global coordinate -frame to an edge-aligned local frame. +radius `rcut`. Each edge stores the displacement vector, smooth cutoff +weights, radial basis features, and the rotation between the global frame and +an edge-aligned local frame. These edge features are built once per forward +call and reused by all interaction blocks. One DPA4/SeZM interaction block consists of the following operations: @@ -59,56 +102,76 @@ DPA4/SeZM stores intermediate features as SO(3)-equivariant coefficients. A feature block with maximum degree `lmax` contains all degrees `l = 0, ..., lmax`, and each degree has `2l + 1` angular components. -DPA4/SeZM avoids the most expensive part of a full SO(3) operation by working -in a local frame on each edge. In that frame, rotations around the edge -axis become SO(2) operations. The descriptor retains only orders -`|m| <= mmax` inside the SO(2) convolution, reducing angular cost while -preserving the required rotation behavior. +The model reduces angular cost by working in a local frame on each edge. In +that frame, rotations around the edge axis become SO(2) operations. The SO(2) +convolution retains orders `|m| <= mmax`, or the per-block value specified by +`m_schedule`, while preserving the required equivariant transformation +behavior. Two schedules control the angular width: - `l_schedule` sets the SO(3) degree used by each block. A schedule such as - `[3, 3, 2]` uses higher degrees in early blocks and truncates them in - later blocks. + `[3, 3, 2]` uses higher degrees in early blocks and truncates them in later + blocks. - `mmax` or `m_schedule` sets how many SO(2) orders are retained in the edge-local convolution. The angular schedule is one of the primary accuracy-cost controls in DPA4/SeZM. Larger angular spaces can represent more complex local chemistry, -but the cost grows quickly with `lmax`. For many systems, a -non-increasing `l_schedule` provides a practical compromise. +but the cost grows quickly with `lmax`. For many systems, a non-increasing +`l_schedule` provides a practical compromise. ## Radial basis and smooth cutoff -Every edge uses a radial basis multiplied by a smooth envelope. The -default basis is Bessel-like, and a Gaussian basis is also available. The -cutoff envelope is constructed so that its value and first three -derivatives vanish at `rcut`. This smoothness is important for molecular -dynamics because nonsmooth descriptor cutoffs would be inherited by force -derivatives. +Every edge uses a radial basis multiplied by a smooth envelope. The default +basis is Bessel-like, and a Gaussian basis is also available through +`basis_type`. The cutoff envelope is constructed so that its value and first +three derivatives vanish at `rcut`. This smoothness is important for +molecular dynamics because nonsmooth descriptor cutoffs would be inherited by +force derivatives. DPA4/SeZM uses two envelope exponents through `env_exp`: - the first exponent controls the radial basis envelope, -- the second controls message-passing edge weights. +- the second exponent controls message-passing edge weights. -Increasing the exponent keeps the envelope closer to one for more of the -cutoff range before it drops near `rcut`. +Increasing an exponent keeps the corresponding envelope closer to one for +more of the cutoff range before it drops near `rcut`. ## Attention and focus streams -DPA4/SeZM can aggregate edge messages either by envelope-weighted scatter or by -attention. When attention is enabled, the cutoff envelope also -participates in the softmax normalization. Edges near the cutoff are therefore -smoothly suppressed in both the numerator and the denominator, avoiding -nonsmooth contributions from the normalization term. - -The SO(2) convolution can also use multiple focus streams. These streams -process the same edge geometry in parallel and are then combined through -scalar weights. This design is not a sparse mixture of experts: all focus -streams are evaluated before soft reweighting. The additional capacity -helps the convolution distinguish different local patterns while -preserving equivariance. +DPA4/SeZM can aggregate edge messages either by envelope-weighted scatter or +by attention. When attention is enabled with `n_atten_head > 0`, the cutoff +envelope also participates in the softmax normalization. Edges near the +cutoff are therefore smoothly suppressed in both the numerator and the +denominator, avoiding nonsmooth contributions from the normalization term. + +The SO(2) convolution can also use multiple focus streams through `n_focus`. +These streams process the same edge geometry in parallel and are then +combined through scalar weights. This design is not a sparse mixture of +experts: all focus streams are evaluated before soft reweighting. The +additional capacity helps the convolution distinguish different local +patterns while preserving equivariance. + +## Grid nonlinearities + +Several DPA4/SeZM branches can use sphere-grid or SO(3)-grid nonlinearities +inside the equivariant network. The most commonly used public switches are: + +- `s2_activation`, which enables S2-grid nonlinearities for the SO(2) branch + and/or the block-internal feed-forward branch. +- `ffn_so3_grid`, which uses an SO(3) Wigner-D grid in the block-internal + feed-forward path. +- `lebedev_quadrature`, which selects packaged Lebedev quadrature rules for + enabled S2-grid branches. +- `grid_mlp` and `grid_branch`, which select the polynomial point-wise MLP or + the scalar-routed polynomial branch mixer for each grid path. Each is either + a single value applied to every path or a list + `[node_wise, message_node, ffn]`. + +These options affect the expressiveness and cost of the equivariant +nonlinearity. The final `l = 0` output descriptor remains a scalar feature +tensor consumed by the fitting network. ## Environment-seeded initial features @@ -125,26 +188,26 @@ model closed over the one-hop neighbor shell. ## Zone bridging and ZBL -DPA4/SeZM includes an optional short-range bridge for analytical repulsion. The -typical use case is ZBL: +DPA4/SeZM includes an optional short-range bridge for analytical repulsion. +The typical use case is ZBL: ```math E_i = E_i^{\mathrm{DPA4/SeZM}} + E_i^{\mathrm{ZBL}}. ``` The purpose of zone bridging is to combine the analytical short-range -repulsion with the learned model while preventing uncontrolled learned -forces in the same protected region. +repulsion with the learned model while preventing uncontrolled learned forces +in the same protected region. Zone bridging has two pieces: 1. Distances below `bridging_r_inner` are clamped before they enter the descriptor. Between `bridging_r_inner` and `bridging_r_outer`, a smooth polynomial transitions back to the true distance. -1. A source gate suppresses message propagation from atoms involved in - frozen short-range pairs. This blocks multi-hop leakage, where a third - atom could otherwise carry information about the frozen pair back into - the learned energy. +1. A source gate suppresses message propagation from atoms involved in frozen + short-range pairs. This blocks multi-hop leakage, where a third atom could + otherwise carry information about the frozen pair back into the learned + energy. This gives a controlled decomposition in the protected region: @@ -152,8 +215,8 @@ This gives a controlled decomposition in the protected region: E_\mathrm{total}(r) = E_\mathrm{ZBL}(r) + E_\mathrm{model}(\tilde r), ``` -where $r$ is the true distance and $\tilde r$ is the clamped distance seen -by the descriptor. +where $r$ is the true distance and $\tilde r$ is the clamped distance seen by +the descriptor. Enable zone bridging with: @@ -174,8 +237,8 @@ complete ZBL input example. ## Fitting network -DPA4/SeZM uses `dpa4_ener` as the energy fitting network name in input files. -It is a GLU-based fitting network that maps scalar descriptors to atomic +DPA4/SeZM uses the `dpa4_ener` energy fitting implementation. It is selected +automatically by the DPA4 model scaffold and maps scalar descriptors to atomic energies. The fitting network uses the same common keys as DeePMD's standard energy @@ -188,13 +251,13 @@ fitting network: - `numb_fparam` - `numb_aparam` -The hidden layers use GLU-style transformations. If `neuron` is `[0]`, -the fitting network uses a direct projection from descriptor channels to -atomic energy. This compact setting is useful for small examples and quick +The hidden layers use GLU-style transformations. If `neuron` is `[0]`, the +fitting network uses a direct projection from descriptor channels to atomic +energy. This compact setting is useful for small examples and quick validation tests. -For shared-fitting multitask training, DPA4/SeZM supports case embeddings. With -`case_film_embd: true`, the case vector modulates the fitting network +For shared-fitting multitask training, DPA4/SeZM supports case embeddings. +With `case_film_embd: true`, the case vector modulates the fitting network instead of being concatenated directly to the descriptor. This keeps the descriptor case-independent while allowing the energy map to depend on the task branch. @@ -202,13 +265,26 @@ task branch. ## Configuration For a complete training input, see `examples/water/dpa4/input.json`. The -example keeps the water dataset paths local to the repository while using a -parameter set close to the pretrained DPA4-Air model. +example uses a compact water setup with the DPA4 model type, SeZM descriptor +options, `dpa4_ener` fitting settings, and the standard conservative energy +loss. Its structure is closer to a DPA4-Neo-style compact configuration than +to the DPA4-Air pretrained configuration. + +Common descriptor controls include: + +- `sel` and `rcut` for the neighbor list. +- `channels`, `n_radial`, and `basis_type` for feature width and radial + resolution. +- `lmax`, `l_schedule`, `mmax`, and `m_schedule` for angular resolution. +- `n_blocks`, `so2_layers`, and `ffn_blocks` for network depth. +- `n_focus` and `n_atten_head` for focus streams and attention aggregation. +- `use_env_seed`, `s2_activation`, `ffn_so3_grid`, and `message_node_so3` for + the main geometric feature paths. +- `use_amp` and `precision` for training precision. ## Training modes -The recommended training objective is the standard conservative energy -loss: +The recommended training objective is the standard conservative energy loss: ```json { @@ -232,13 +308,14 @@ DPA4/SeZM also has an experimental direct-force denoising mode selected by: } ``` -Use `dens` only when the direct-force denoising head is required. It is -not the default training path. +Use `dens` only when the direct-force denoising head is required. It is not +the default training path. See `examples/water/dpa4/input_dens.json` for an +example input. ## Spin -DPA4/SeZM supports the DeePMD-kit spin convention in the PyTorch backend. Keep -the DPA4/SeZM type string and add the standard `model.spin` block: +DPA4/SeZM supports the DeePMD-kit spin convention in the PyTorch backend. +Keep the DPA4/SeZM type string and add the standard `model.spin` block: ```json { @@ -258,7 +335,6 @@ the DPA4/SeZM type string and add the standard `model.spin` block: ] }, "descriptor": { - "type": "dpa4", "sel": 120, "rcut": 6.0 } @@ -269,18 +345,20 @@ the DPA4/SeZM type string and add the standard `model.spin` block: The spin path supports the conservative `ener_spin` loss. The direct-force denoising mode is not used together with spin. See [training spin energy models](train-energy-spin.md) for the common spin -training settings. +training settings, and `examples/water/dpa4/input-spin.json` for a DPA4-style +input example. ## Performance and hardware recommendations ### bfloat16 automatic mixed precision DPA4/SeZM supports automatic mixed precision (AMP) during training through the -descriptor option `use_amp`. This option uses bfloat16 (bf16) autocast for -eligible CUDA operations. In typical DPA4/SeZM workloads, bf16 AMP reduces -memory usage and may improve throughput while preserving fitted accuracy, but -the final accuracy should be validated for the target system. Numerically -sensitive geometric operations are kept in promoted precision. +descriptor option `use_amp`, whose default value is `true`. This option uses +bfloat16 (bf16) autocast for eligible CUDA operations. In typical DPA4/SeZM +workloads, bf16 AMP reduces memory usage and may improve throughput while +preserving fitted accuracy; no visible accuracy degradation is expected in +normal DPA4/SeZM training. Numerically sensitive geometric operations are kept +in promoted precision. When the GPU provides native bf16 support, enabling `use_amp` is recommended: @@ -294,10 +372,22 @@ When the GPU provides native bf16 support, enabling `use_amp` is recommended: } ``` -On GPUs without native bf16 support, leave `use_amp` disabled to avoid runtime -errors or additional conversion overhead. On NVIDIA hardware, native bf16 -support starts with the Ampere generation, including A100-series accelerators -and RTX 30-series GPUs, and continues on newer architectures. +On GPUs without native bf16 support, explicitly set `use_amp` to `false` to +avoid runtime errors or additional conversion overhead: + +```json +{ + "model": { + "descriptor": { + "use_amp": false + } + } +} +``` + +On NVIDIA hardware, native bf16 support starts with the Ampere generation, +including A100-series accelerators and RTX 30-series GPUs, and continues on +newer architectures. ### Experimental `torch.compile` path @@ -311,12 +401,10 @@ DPA4/SeZM can train through an experimental `torch.compile` path: } ``` -This path is useful for force-loss training because the model first -differentiates energy to obtain forces and then differentiates the force -loss with respect to model parameters. The training graph therefore contains -higher-order autograd operations, including mixed derivatives induced by -differentiating force losses with respect to model parameters. DPA4/SeZM -traces this graph before passing it to Inductor. +This path is useful for force-loss training, where differentiating the force +loss requires higher-order derivatives through the conservative +energy-gradient path. DPA4/SeZM traces this path before passing it to +Inductor. This path is experimental and may expose PyTorch compiler issues. It currently requires `torch==2.11`; other PyTorch versions are not supported for this @@ -325,7 +413,26 @@ Silicon Macs are also supported. It has been tested with Python 3.13. If the compiled path fails or produces unexpected behavior, please report the issue with the PyTorch version, CUDA version, GPU model, and a minimal input file. -For evaluation-time compile during validation, set: +### Inference environment variables + +DPA4/SeZM reads inference-related environment variables when the PyTorch model +is constructed. If these variables are already exported in the shell, they +take precedence over values written in the input file. Changing them after +model construction does not affect that model instance. + +`DP_COMPILE_INFER` controls whether evaluation and inference forwards use the +DPA4/SeZM compile path: + +```bash +export DP_COMPILE_INFER=1 +``` + +Accepted true values are `1`, `true`, `yes`, and `on`; accepted false values +are `0`, `false`, `no`, and `off`. Enabling this path has the same PyTorch +version requirements as `model.use_compile`. + +During training validation, the same setting can be requested in the input +file: ```json { @@ -335,35 +442,53 @@ For evaluation-time compile during validation, set: } ``` -You can also set `DP_COMPILE_INFER=1` in the environment before training. +The trainer translates this option into `DP_COMPILE_INFER=1` before model +construction, unless the shell environment already defines `DP_COMPILE_INFER`. + +`DP_TF32_INFER` controls the float32 matmul precision used by evaluation and +inference forwards on CUDA: + +- `0`: use PyTorch `highest` precision. This is the default. +- `1`: use PyTorch `high` precision. +- `2`: use PyTorch `medium` precision. + +During training validation, the input option +`validating.tf32_infer: true` is translated into `DP_TF32_INFER=1` before +model construction, again without overriding an explicitly exported +environment variable. Training forwards are controlled separately by +`model.enable_tf32`. + +For molecular dynamics and other workflows that are sensitive to potential +energy surface smoothness, keep `DP_TF32_INFER=0`. Enabling TF32 inference may +leave energy and force MAE nearly unchanged while making the potential energy +surface less smooth. For less smoothness-sensitive evaluation or screening +workloads, `DP_TF32_INFER=1` or `2` may be useful for improving throughput. ### Hardware selection DPA4/SeZM is designed for fp32 training and inference. Hardware selection should therefore be based primarily on fp32 throughput rather than fp64 throughput. In contrast to workloads dominated by double-precision linear -algebra, DPA4/SeZM does not require GPUs with especially strong fp64 performance. +algebra, DPA4/SeZM does not require GPUs with especially strong fp64 +performance. For practical training, prefer GPUs that combine high fp32 FLOPS with native bf16 support. Native bf16 enables the recommended AMP path, lowering memory usage and often improving throughput. Because AMP can substantially reduce the activation memory footprint, DPA4/SeZM training usually does not require unusually large-memory GPUs once the target system and batch size fit. In that -regime, native bf16 support and fp32 FLOPS are usually more important selection -criteria than maximum device memory. +regime, native bf16 support and fp32 FLOPS are usually more important +selection criteria than maximum device memory. ## LoRA fine-tuning -DPA4/SeZM supports LoRA adapters on its SO(3) and SO(2) linear layers. A typical -input block is: +DPA4/SeZM supports LoRA adapters on its SO(3) and SO(2) linear layers. This +mode is intended for single-task fine-tuning. A typical input block is: ```json { "model": { "type": "dpa4", - "descriptor": { - "type": "dpa4" - }, "lora": { "rank": 16, "alpha": 16.0 @@ -390,13 +515,14 @@ dp --pt freeze -c model.ckpt.pt -o frozen_model ``` The PyTorch backend detects DPA4/SeZM and writes `frozen_model.pt2`. Use this -file with LAMMPS: +`.pt2` file with LAMMPS: ```lammps pair_style deepmd frozen_model.pt2 pair_coeff * * O H ``` +The ordinary TorchScript freeze path is not used for DPA4/SeZM checkpoints. A small LAMMPS example is in `examples/water/dpa4/lmp/`. ## Data format @@ -411,3 +537,23 @@ downstream `pair_coeff` mapping. - Model compression is not supported. - Export uses `.pt2`; the ordinary TorchScript freeze path is not used for DPA4/SeZM checkpoints. + +## Citation + +If you use DPA4/SeZM, please cite the [DPA4 paper](https://arxiv.org/abs/2606.02419): + +```bibtex +@article{li2026dpa4, + title = {{DPA4}: Pushing the Accuracy-Cost Frontier of Interatomic + Potentials with {EMFA} {SO(2)} Convolution}, + author = {Li, Tiancheng and Li, Wentao and Peng, Anyang and Xue, Jianming + and Zhang, Linfeng and Zhang, Duo and Wang, Han}, + journal = {arXiv preprint arXiv:2606.02419}, + year = {2026}, + eprint = {2606.02419}, + archivePrefix = {arXiv}, + primaryClass = {physics.chem-ph}, + doi = {10.48550/arXiv.2606.02419}, + url = {https://arxiv.org/abs/2606.02419} +} +``` diff --git a/examples/water/dpa4/README.md b/examples/water/dpa4/README.md index 067ce609ca..15d6a383f0 100644 --- a/examples/water/dpa4/README.md +++ b/examples/water/dpa4/README.md @@ -6,8 +6,8 @@ water example dataset. The recommended model and descriptor type is `DPA4`; Input files: -- `input.json`: baseline conservative energy training, using a parameter set - close to the pretrained DPA4-Air model. +- `input.json`: baseline conservative energy training, using a compact + DPA4-Neo-style parameter set. - `input-zbl.json`: energy training with ZBL zone bridging. - `input-spin.json`: spin-energy training with the DeePMD spin convention. - `input_dens.json`: direct-force denoising training. diff --git a/examples/water/dpa4/input-spin.json b/examples/water/dpa4/input-spin.json index 6077ed273c..c82754212d 100644 --- a/examples/water/dpa4/input-spin.json +++ b/examples/water/dpa4/input-spin.json @@ -16,54 +16,32 @@ ] }, "descriptor": { - "type": "DPA4", "sel": 120, "rcut": 6.0, - "env_exp": [ - 7, - 5 - ], - "channels": 64, + "channels": 32, "n_radial": 16, - "radial_mlp": [ - 0 - ], "use_env_seed": true, - "random_gamma": true, "lmax": 3, "mmax": 1, - "n_blocks": 3, - "so2_layers": 4, - "so2_norm": false, - "so2_attn_res": "none", + "n_blocks": 2, + "so2_layers": 3, "radial_so2_mode": "degree_channel", "radial_so2_rank": 1, - "n_focus": 1, + "n_focus": 2, "focus_dim": 0, "n_atten_head": 1, - "atten_f_mix": false, - "atten_v_proj": false, - "atten_o_proj": false, "ffn_neurons": 0, + "ffn_so3_grid": true, "grid_mlp": false, - "ffn_blocks": 1, + "grid_branch": 1, + "ffn_blocks": 2, "sandwich_norm": [ false, true, true, false ], - "mlp_bias": false, - "layer_scale": false, - "full_attn_res": "none", - "block_attn_res": "none", - "s2_activation": [ - false, - true - ], - "lebedev_quadrature": true, - "activation_function": "silu", - "glu_activation": true, + "message_node_so3": true, "use_amp": true, "precision": "float32", "seed": 42 @@ -72,7 +50,6 @@ "neuron": [ 0 ], - "activation_function": "silu", "precision": "float32", "seed": 42 }, @@ -99,9 +76,6 @@ }, "optimizer": { "type": "HybridMuon", - "muon_mode": "slice", - "magma_muon": true, - "lr_adjust": 0.0, "weight_decay": 0.001 }, "training": { diff --git a/examples/water/dpa4/input-zbl.json b/examples/water/dpa4/input-zbl.json index 622966d07d..14d675dcfc 100644 --- a/examples/water/dpa4/input-zbl.json +++ b/examples/water/dpa4/input-zbl.json @@ -7,54 +7,32 @@ "H" ], "descriptor": { - "type": "DPA4", "sel": 120, "rcut": 6.0, - "env_exp": [ - 7, - 5 - ], - "channels": 64, + "channels": 32, "n_radial": 16, - "radial_mlp": [ - 0 - ], "use_env_seed": true, - "random_gamma": true, "lmax": 3, "mmax": 1, - "n_blocks": 3, - "so2_layers": 4, - "so2_norm": false, - "so2_attn_res": "none", + "n_blocks": 2, + "so2_layers": 3, "radial_so2_mode": "degree_channel", "radial_so2_rank": 1, - "n_focus": 1, + "n_focus": 2, "focus_dim": 0, "n_atten_head": 1, - "atten_f_mix": false, - "atten_v_proj": false, - "atten_o_proj": false, "ffn_neurons": 0, + "ffn_so3_grid": true, "grid_mlp": false, - "ffn_blocks": 1, + "grid_branch": 1, + "ffn_blocks": 2, "sandwich_norm": [ false, true, true, false ], - "mlp_bias": false, - "layer_scale": false, - "full_attn_res": "none", - "block_attn_res": "none", - "s2_activation": [ - false, - true - ], - "lebedev_quadrature": true, - "activation_function": "silu", - "glu_activation": true, + "message_node_so3": true, "use_amp": true, "precision": "float32", "seed": 42 @@ -63,7 +41,6 @@ "neuron": [ 0 ], - "activation_function": "silu", "precision": "float32", "seed": 42 }, @@ -95,9 +72,6 @@ }, "optimizer": { "type": "HybridMuon", - "muon_mode": "slice", - "magma_muon": true, - "lr_adjust": 0.0, "weight_decay": 0.001 }, "training": { diff --git a/examples/water/dpa4/input.json b/examples/water/dpa4/input.json index e4fb10f3ed..34e316086a 100644 --- a/examples/water/dpa4/input.json +++ b/examples/water/dpa4/input.json @@ -7,54 +7,40 @@ "H" ], "descriptor": { - "type": "DPA4", "sel": 120, "rcut": 6.0, - "env_exp": [ - 7, - 5 - ], - "channels": 64, + "channels": 32, "n_radial": 16, - "radial_mlp": [ - 0 - ], "use_env_seed": true, - "random_gamma": true, "lmax": 3, "mmax": 1, - "n_blocks": 3, - "so2_layers": 4, - "so2_norm": false, - "so2_attn_res": "none", + "n_blocks": 2, + "so2_layers": 3, "radial_so2_mode": "degree_channel", "radial_so2_rank": 1, - "n_focus": 1, + "n_focus": 2, "focus_dim": 0, "n_atten_head": 1, - "atten_f_mix": false, - "atten_v_proj": false, - "atten_o_proj": false, "ffn_neurons": 0, - "grid_mlp": false, - "ffn_blocks": 1, + "ffn_so3_grid": true, + "grid_mlp": [ + false, + false, + false + ], + "grid_branch": [ + 1, + 1, + 1 + ], + "ffn_blocks": 2, "sandwich_norm": [ false, true, true, false ], - "mlp_bias": false, - "layer_scale": false, - "full_attn_res": "none", - "block_attn_res": "none", - "s2_activation": [ - false, - true - ], - "lebedev_quadrature": true, - "activation_function": "silu", - "glu_activation": true, + "message_node_so3": true, "use_amp": true, "precision": "float32", "seed": 42 @@ -63,7 +49,6 @@ "neuron": [ 0 ], - "activation_function": "silu", "precision": "float32", "seed": 42 }, @@ -92,9 +77,6 @@ }, "optimizer": { "type": "HybridMuon", - "muon_mode": "slice", - "magma_muon": true, - "lr_adjust": 0.0, "weight_decay": 0.001 }, "training": { @@ -134,5 +116,9 @@ "profiling_file": "timeline.json", "zero_stage": 1, "seed": 42 + }, + "validating": { + "compiled_infer": false, + "tf32_infer": false } } diff --git a/examples/water/dpa4/input_dens.json b/examples/water/dpa4/input_dens.json index 0379f5040e..266f68eba7 100644 --- a/examples/water/dpa4/input_dens.json +++ b/examples/water/dpa4/input_dens.json @@ -7,54 +7,32 @@ "H" ], "descriptor": { - "type": "DPA4", "sel": 120, "rcut": 6.0, - "env_exp": [ - 7, - 5 - ], - "channels": 64, + "channels": 32, "n_radial": 16, - "radial_mlp": [ - 0 - ], "use_env_seed": true, - "random_gamma": true, "lmax": 3, "mmax": 1, - "n_blocks": 3, - "so2_layers": 4, - "so2_norm": false, - "so2_attn_res": "none", + "n_blocks": 2, + "so2_layers": 3, "radial_so2_mode": "degree_channel", "radial_so2_rank": 1, - "n_focus": 1, + "n_focus": 2, "focus_dim": 0, "n_atten_head": 1, - "atten_f_mix": false, - "atten_v_proj": false, - "atten_o_proj": false, "ffn_neurons": 0, + "ffn_so3_grid": true, "grid_mlp": false, - "ffn_blocks": 1, + "grid_branch": 1, + "ffn_blocks": 2, "sandwich_norm": [ false, true, true, false ], - "mlp_bias": false, - "layer_scale": false, - "full_attn_res": "none", - "block_attn_res": "none", - "s2_activation": [ - false, - true - ], - "lebedev_quadrature": true, - "activation_function": "silu", - "glu_activation": true, + "message_node_so3": true, "use_amp": true, "precision": "float32", "seed": 42 @@ -63,7 +41,6 @@ "neuron": [ 0 ], - "activation_function": "silu", "precision": "float32", "seed": 42 }, @@ -93,9 +70,6 @@ }, "optimizer": { "type": "HybridMuon", - "muon_mode": "slice", - "magma_muon": true, - "lr_adjust": 0.0, "weight_decay": 0.001 }, "training": { diff --git a/examples/water/dpa4/input_multitask.json b/examples/water/dpa4/input_multitask.json index 8abe0e3ffb..15e97cdaba 100644 --- a/examples/water/dpa4/input_multitask.json +++ b/examples/water/dpa4/input_multitask.json @@ -9,54 +9,32 @@ "H" ], "descriptor": { - "type": "DPA4", "sel": 120, "rcut": 6.0, - "env_exp": [ - 7, - 5 - ], - "channels": 64, + "channels": 32, "n_radial": 16, - "radial_mlp": [ - 0 - ], "use_env_seed": true, - "random_gamma": true, "lmax": 3, "mmax": 1, - "n_blocks": 3, - "so2_layers": 4, - "so2_norm": false, - "so2_attn_res": "none", + "n_blocks": 2, + "so2_layers": 3, "radial_so2_mode": "degree_channel", "radial_so2_rank": 1, - "n_focus": 1, + "n_focus": 2, "focus_dim": 0, "n_atten_head": 1, - "atten_f_mix": false, - "atten_v_proj": false, - "atten_o_proj": false, "ffn_neurons": 0, + "ffn_so3_grid": true, "grid_mlp": false, - "ffn_blocks": 1, + "grid_branch": 1, + "ffn_blocks": 2, "sandwich_norm": [ false, true, true, false ], - "mlp_bias": false, - "layer_scale": false, - "full_attn_res": "none", - "block_attn_res": "none", - "s2_activation": [ - false, - true - ], - "lebedev_quadrature": true, - "activation_function": "silu", - "glu_activation": true, + "message_node_so3": true, "use_amp": true, "precision": "float32", "seed": 42 @@ -66,7 +44,6 @@ "neuron": [ 0 ], - "activation_function": "silu", "precision": "float32", "dim_case_embd": 2, "case_film_embd": true, @@ -136,9 +113,6 @@ }, "optimizer": { "type": "HybridMuon", - "muon_mode": "slice", - "magma_muon": true, - "lr_adjust": 0.0, "weight_decay": 0.001 }, "training": { diff --git a/examples/water/dpa4/lora_ft.json b/examples/water/dpa4/lora_ft.json index 5832fdf837..7214133a77 100644 --- a/examples/water/dpa4/lora_ft.json +++ b/examples/water/dpa4/lora_ft.json @@ -7,54 +7,32 @@ "H" ], "descriptor": { - "type": "DPA4", "sel": 120, "rcut": 6.0, - "env_exp": [ - 7, - 5 - ], - "channels": 64, + "channels": 32, "n_radial": 16, - "radial_mlp": [ - 0 - ], "use_env_seed": true, - "random_gamma": true, "lmax": 3, "mmax": 1, - "n_blocks": 3, - "so2_layers": 4, - "so2_norm": false, - "so2_attn_res": "none", + "n_blocks": 2, + "so2_layers": 3, "radial_so2_mode": "degree_channel", "radial_so2_rank": 1, - "n_focus": 1, + "n_focus": 2, "focus_dim": 0, "n_atten_head": 1, - "atten_f_mix": false, - "atten_v_proj": false, - "atten_o_proj": false, "ffn_neurons": 0, + "ffn_so3_grid": true, "grid_mlp": false, - "ffn_blocks": 1, + "grid_branch": 1, + "ffn_blocks": 2, "sandwich_norm": [ false, true, true, false ], - "mlp_bias": false, - "layer_scale": false, - "full_attn_res": "none", - "block_attn_res": "none", - "s2_activation": [ - false, - true - ], - "lebedev_quadrature": true, - "activation_function": "silu", - "glu_activation": true, + "message_node_so3": true, "use_amp": true, "precision": "float32", "seed": 42 @@ -63,7 +41,6 @@ "neuron": [ 0 ], - "activation_function": "silu", "precision": "float32", "seed": 42 }, @@ -94,9 +71,6 @@ }, "optimizer": { "type": "HybridMuon", - "muon_mode": "slice", - "magma_muon": true, - "lr_adjust": 0.0, "weight_decay": 0.001 }, "training": { diff --git a/source/tests/pt/model/test_descriptor_sezm.py b/source/tests/pt/model/test_descriptor_sezm.py index 790d6571b7..2be18bfcca 100644 --- a/source/tests/pt/model/test_descriptor_sezm.py +++ b/source/tests/pt/model/test_descriptor_sezm.py @@ -16,6 +16,7 @@ SO2Linear, WignerDCalculator, build_edge_quaternion, + build_gie_zonal_index, build_m_major_l_index, quaternion_multiply, quaternion_to_rotation_matrix, @@ -184,10 +185,22 @@ def test_forward_with_descriptor_variants(self) -> None: s2_activation=[False, True], lebedev_quadrature=[False, True], ), + "message_node_s2": _descriptor_kwargs( + channels=4, + n_focus=2, + focus_dim=0, + so2_layers=2, + message_node_s2=True, + ), "gaussian_basis": _descriptor_kwargs( channels=4, basis_type="gaussian", ), + "extra_node_l": _descriptor_kwargs( + channels=4, + extra_node_l=1, + s2_activation=[False, True], + ), "radial_so2_degree": _descriptor_kwargs( channels=4, n_focus=2, @@ -225,6 +238,7 @@ def test_forward_with_attention_variants(self) -> None: precision="float32", seed=123, s2_activation=[False, True], + extra_node_l=1, ), "mixed_so2_attention": _attention_descriptor_kwargs( precision="float32", @@ -418,6 +432,26 @@ def test_serialization_deserialization(self) -> None: s2_activation=[False, True], lebedev_quadrature=[False, True], ), + "message_node_s2": _descriptor_kwargs( + precision="float32", + channels=4, + n_focus=2, + focus_dim=0, + so2_layers=2, + n_radial=3, + radial_mlp=[6], + ffn_neurons=8, + message_node_s2=True, + ), + "extra_node_l": _descriptor_kwargs( + precision="float32", + channels=4, + n_radial=3, + radial_mlp=[6], + ffn_neurons=8, + extra_node_l=1, + s2_activation=[False, True], + ), "radial_so2_degree": _descriptor_kwargs( precision="float32", channels=4, @@ -530,28 +564,6 @@ def test_charge_spin_sparse_edge_conditioning(self) -> None: self.assertFalse(torch.allclose(desc_ref, desc_shifted)) torch.testing.assert_close(desc_ref, desc_restored, atol=1e-6, rtol=1e-6) - def test_plain_descriptor_deserializes_without_condition_config(self) -> None: - """Plain descriptors should not depend on charge/spin condition fields.""" - coord, atype, nlist = _tiny_two_atom_system(self.device, dtype=torch.float32) - extended_coord = coord.reshape(1, -1) - model = DescrptSeZM(**_descriptor_kwargs(seed=123)) - self.assertTrue( - all("charge_spin_embedding" not in key for key in model.state_dict()) - ) - data = model.serialize() - data["config"].pop("add_chg_spin_ebd", None) - data["config"].pop("default_chg_spin", None) - - restored = DescrptSeZM.deserialize(data) - desc_ref, *_ = model(extended_coord, atype, nlist) - desc_new, *_ = restored(extended_coord, atype, nlist) - - self.assertFalse(restored.add_chg_spin_ebd) - self.assertTrue( - all("charge_spin_embedding" not in key for key in restored.state_dict()) - ) - torch.testing.assert_close(desc_ref, desc_new, atol=1e-6, rtol=1e-6) - def test_seed_reproducibility(self) -> None: """Test that fixed seed produces identical model initialization.""" for prec in ["float64", "float32", "bfloat16"]: @@ -696,9 +708,11 @@ def _extract_l_block( def test_orthogonality(self) -> None: """Test D @ D^T = I for random quaternions.""" - for dtype, lmax in itertools.product([torch.float64, torch.float32], [1, 3, 6]): + for dtype, lmax in itertools.product( + [torch.float64, torch.float32], [1, 3, 5, 6, 8, 10] + ): atol, rtol = self._get_tols(dtype) - wigner = WignerDCalculator(lmax=lmax, dtype=dtype) + wigner = WignerDCalculator(lmax=lmax, dtype=dtype).to(self.device) edge_quat = _random_quaternion(self.batch, device=self.device, dtype=dtype) D_full, Dt_full = wigner(edge_quat) @@ -721,10 +735,12 @@ def test_orthogonality(self) -> None: def test_group_property(self) -> None: """Test group property in quaternion composition order.""" - for dtype, lmax in itertools.product([torch.float64, torch.float32], [1, 3, 6]): + for dtype, lmax in itertools.product( + [torch.float64, torch.float32], [1, 3, 5, 6, 8, 10] + ): atol = 1e-10 if dtype == torch.float64 else 5e-4 rtol = 1e-10 if dtype == torch.float64 else 5e-4 - wigner = WignerDCalculator(lmax=lmax, dtype=dtype) + wigner = WignerDCalculator(lmax=lmax, dtype=dtype).to(self.device) q1 = _random_quaternion(self.batch, device=self.device, dtype=dtype) q2 = _random_quaternion(self.batch, device=self.device, dtype=dtype) @@ -761,7 +777,7 @@ def test_l1_matches_vector_representation(self) -> None: edge_quat = _random_quaternion(self.batch, device=self.device, dtype=dtype) rot = quaternion_to_rotation_matrix(edge_quat) - wigner = WignerDCalculator(lmax=1, dtype=dtype) + wigner = WignerDCalculator(lmax=1, dtype=dtype).to(self.device) D_full, Dt_full = wigner(edge_quat) D1 = self._extract_l_block(D_full, 1) Dt1 = self._extract_l_block(Dt_full, 1) @@ -782,10 +798,61 @@ def test_l1_matches_vector_representation(self) -> None: msg=f"l=1 transpose block mismatch for WignerDCalculator, dtype={dtype}", ) + def test_zonal_matches_full_wigner_gather(self) -> None: + """Test GIE zonal coupling matches the full Wigner-D gather.""" + for dtype, lmax in itertools.product( + [torch.float64, torch.float32], [0, 1, 3, 5, 6, 8, 10] + ): + atol, rtol = self._get_tols(dtype) + wigner = WignerDCalculator(lmax=lmax, dtype=dtype).to(self.device) + if lmax < 11: + self.assertFalse(hasattr(wigner, "poly_coeffs")) + edge_quat = _random_quaternion(self.batch, device=self.device, dtype=dtype) + _, dt_full = wigner(edge_quat) + node_row_index, node_zonal_m0_col_index, _ = build_gie_zonal_index( + lmax, + device=self.device, + ) + expected = dt_full[:, node_row_index, node_zonal_m0_col_index] + actual = wigner.forward_zonal(edge_quat) + torch.testing.assert_close(actual, expected, atol=atol, rtol=rtol) + if lmax >= 2: + lmin = 2 + suffix_start = lmin * lmin - 1 + actual_suffix = wigner.forward_zonal(edge_quat, lmin=lmin) + torch.testing.assert_close( + actual_suffix, + expected[:, suffix_start:], + atol=atol, + rtol=rtol, + ) + + def test_special_blocks_match_generic_reference_on_cpu(self) -> None: + """Test specialized l=2..10 Wigner blocks against the generic reference.""" + device = torch.device("cpu") + dtype = torch.float64 + edge_quat = _random_quaternion(4, device=device, dtype=dtype) + wigner = WignerDCalculator(lmax=10, dtype=dtype).to(device) + d_full, _ = wigner(edge_quat) + reference = WignerDCalculator._compute_generic_reference_blocks( + edge_quat, + lmax=10, + dtype=dtype, + device=device, + ) + for degree in range(2, 11): + torch.testing.assert_close( + self._extract_l_block(d_full, degree), + reference[degree], + atol=1.0e-12, + rtol=1.0e-10, + msg=f"Special Wigner block mismatch for l={degree}", + ) + def test_pole_path_gradient_matches_finite_difference(self) -> None: """Check one pole-crossing Wigner probe against finite differences.""" for dtype in [torch.float64, torch.float32]: - wigner = WignerDCalculator(lmax=6, dtype=dtype) + wigner = WignerDCalculator(lmax=6, dtype=dtype).to(self.device) atol = 5.0e-8 if dtype == torch.float64 else 2.0e-6 rtol = 1.0e-6 if dtype == torch.float64 else 2.0e-4 for sign in [1.0, -1.0]: @@ -823,7 +890,7 @@ def test_pole_path_gradient_matches_finite_difference(self) -> None: def test_y_crossing_overlap_has_no_large_wigner_jump(self) -> None: """Check chart-overlap continuity for a path that crosses y=0.""" for dtype in [torch.float64, torch.float32]: - wigner = WignerDCalculator(lmax=4, dtype=dtype) + wigner = WignerDCalculator(lmax=4, dtype=dtype).to(self.device) max_allowed = 1.0e-2 if dtype == torch.float64 else 1.5e-2 y_vals = torch.tensor( [-1.0e-3, -5.0e-4, -1.0e-4, 0.0, 1.0e-4, 5.0e-4, 1.0e-3], diff --git a/source/tests/pt/model/test_descriptor_sezm_grid_projection.py b/source/tests/pt/model/test_descriptor_sezm_grid_projection.py new file mode 100644 index 0000000000..9b1a0f6c6f --- /dev/null +++ b/source/tests/pt/model/test_descriptor_sezm_grid_projection.py @@ -0,0 +1,1023 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import unittest + +import torch + +from deepmd.pt.model.descriptor.sezm_nn import ( + S2GridNet, + S2GridProjector, + SO3GridNet, + SO3GridProjector, + WignerDCalculator, + build_edge_quaternion, + build_m_major_index, + load_lebedev_rule, + quaternion_multiply, + quaternion_z_rotation, + resolve_s2_grid_resolution, +) + + +def _random_quaternion( + n_batch: int, + *, + device: torch.device, + dtype: torch.dtype, +) -> torch.Tensor: + """Sample normalized quaternions in ``(w, x, y, z)`` order.""" + q = torch.randn(n_batch, 4, device=device, dtype=dtype) + return q / torch.sqrt(torch.sum(q * q, dim=-1, keepdim=True)) + + +def _rotate_ndfc(x: torch.Tensor, d_matrix: torch.Tensor) -> torch.Tensor: + """Rotate coefficient-layout tensors with shape ``(N, D, F, C)``.""" + return torch.einsum("nij,njfc->nifc", d_matrix, x) + + +def _rotate_nfdc(x: torch.Tensor, d_matrix: torch.Tensor) -> torch.Tensor: + """Rotate coefficient-layout tensors with shape ``(N, F, D, C)``.""" + return torch.einsum("nij,nfjc->nfic", d_matrix, x) + + +def _max_abs_equivariance_error(lhs: torch.Tensor, rhs: torch.Tensor) -> float: + """Compute the maximum absolute equivariance error.""" + return float(torch.max(torch.abs(lhs - rhs)).item()) + + +def _legal_so3_frame_mask(projector: SO3GridProjector) -> torch.Tensor: + mask = torch.ones( + projector.coeff_dim, + dtype=torch.bool, + device=projector.to_grid_mat.device, + ) + n_frames = projector.n_frames + for degree in range(projector.lmax + 1): + for m_order in range(-degree, degree + 1): + packed_idx = degree * degree + degree + m_order + for frame_pos, frame_order in enumerate(projector.frame_set): + flat_idx = packed_idx * n_frames + frame_pos + if flat_idx >= projector.coeff_dim: + continue + if abs(frame_order) > degree: + mask[flat_idx] = False + return mask + + +class TestS2GridProjector(unittest.TestCase): + """Test S2 projection invariants.""" + + def setUp(self) -> None: + self.device = torch.device("cpu") + torch.manual_seed(0) + + def test_lebedev_roundtrip_preserves_bandlimited_coefficients(self) -> None: + """Lebedev quadrature should reconstruct coefficients up to lmax.""" + projector = S2GridProjector( + lmax=3, + dtype=torch.float64, + grid_resolution_list=None, + coefficient_layout="packed", + grid_method="lebedev", + ).to(self.device) + x = torch.randn( + 5, projector.coeff_dim, 3, device=self.device, dtype=torch.float64 + ) + y = projector.from_grid(projector.to_grid(x)) + torch.testing.assert_close(y, x, atol=1e-12, rtol=1e-12) + + +class TestSwiGLUS2Equivariance(unittest.TestCase): + """Test default-grid equivariance of full-m and truncated SwiGLU-S2 activations.""" + + def setUp(self) -> None: + self.device = torch.device("cpu") + torch.manual_seed(0) + + def test_default_full_m_grid_counts_keep_s2_activation_equivariant(self) -> None: + """Default full-m S2 activation grids should keep SO(3) equivariance.""" + # Each case is (lmax, full_m_grid, fp64_tol, fp32_tol). + # e3nn full_m_grid is [R_phi, R_theta] after the square-grid lift. + # Lebedev full_m_grid is [precision, n_points]. + cases_by_method = { + "e3nn": [ + (2, [8, 8], 4.20e-7, 5.00e-6), # local: fp64=3.62e-7, fp32=4.77e-7 + (3, [12, 12], 8.10e-7, 5.00e-6), # local: fp64=7.04e-7, fp32=6.86e-7 + (4, [14, 14], 9.20e-7, 5.00e-6), # local: fp64=7.97e-7, fp32=1.55e-6 + (5, [18, 18], 1.70e-6, 5.00e-6), # local: fp64=1.48e-6, fp32=1.49e-6 + (6, [20, 20], 4.80e-6, 5.00e-6), # local: fp64=4.14e-6, fp32=2.27e-6 + (7, [24, 24], 3.70e-6, 5.00e-6), # local: fp64=3.19e-6, fp32=2.03e-6 + ], + "lebedev": [ + (2, [7, 26], 1.00e-12, 5.00e-6), # local: fp64=2.31e-14, fp32=2.38e-7 + (3, [9, 38], 1.00e-12, 5.00e-6), # local: fp64=3.58e-14, fp32=3.58e-7 + (4, [13, 74], 1.00e-12, 5.00e-6), # local: fp64=5.82e-14, fp32=6.56e-7 + (5, [15, 86], 1.00e-12, 5.00e-6), # local: fp64=3.22e-14, fp32=6.56e-7 + (6, [19, 146], 1.00e-12, 5.00e-6), # local: fp64=7.99e-14, fp32=8.35e-7 + (7, [21, 170], 1.00e-12, 5.00e-6), # local: fp64=6.86e-14, fp32=8.79e-7 + ], + } + dtype_cases = [ + (torch.float64, 0), + (torch.float32, 1), + ] + n_batch = 3 + n_focus = 1 + channels = 2 + + for dtype, tolerance_index in dtype_cases: + for method, cases in cases_by_method.items(): + for lmax, expected_full_m_grid, *tolerances in cases: + with self.subTest( + method=method, + dtype=dtype, + lmax=lmax, + grid=expected_full_m_grid, + ): + self._assert_default_full_m_s2_activation_equivariance( + grid_method=method, + lmax=lmax, + expected_full_m_grid=expected_full_m_grid, + n_batch=n_batch, + n_focus=n_focus, + channels=channels, + dtype=dtype, + tolerance=tolerances[tolerance_index], + ) + + def _assert_default_full_m_s2_activation_equivariance( + self, + *, + grid_method: str, + lmax: int, + expected_full_m_grid: list[int], + n_batch: int, + n_focus: int, + channels: int, + dtype: torch.dtype, + tolerance: float, + op_type: str = "glu", + ) -> None: + """Assert full-m S2 activation equivariance for one method/dtype/lmax case.""" + torch.manual_seed(1234 + lmax) + default_grid = resolve_s2_grid_resolution( + lmax, + lmax, + method=grid_method, + ) + full_m_grid = ( + [max(default_grid), max(default_grid)] + if grid_method == "e3nn" + else default_grid + ) + self.assertEqual(full_m_grid, expected_full_m_grid) + + activation = S2GridNet( + lmax=lmax, + channels=channels, + dtype=dtype, + n_focus=n_focus, + mode="self", + op_type=op_type, + layout="ndfc", + grid_resolution_list=full_m_grid, + coefficient_layout="packed", + grid_method=grid_method, + mlp_bias=False, + trainable=False, + seed=17 + lmax, + ).to(self.device) + self.assertEqual(activation.grid_resolution_list, expected_full_m_grid) + + x = torch.randn( + n_batch, + (lmax + 1) ** 2, + n_focus, + 2 * channels, + device=self.device, + dtype=dtype, + ) + quat = _random_quaternion(n_batch, device=self.device, dtype=dtype) + d_matrix, _ = WignerDCalculator(lmax=lmax, dtype=dtype).to(self.device)(quat) + + y_rotated_input = activation(_rotate_ndfc(x, d_matrix)) + y_then_rotated = _rotate_ndfc(activation(x), d_matrix) + max_error = _max_abs_equivariance_error( + y_rotated_input, + y_then_rotated, + ) + + self.assertLessEqual(max_error, tolerance) + + def test_polynomial_grid_mlp_full_m_s2_equivariance(self) -> None: + """S2 grid MLP should keep full-m SO(3) equivariance.""" + # Each case is (lmax, full_m_grid, fp64_tol, fp32_tol). + cases_by_method = { + "e3nn": [ + (2, [8, 8], 2.00e-7, 5.00e-6), # local: fp64=7.28e-8, fp32=1.34e-7 + (3, [12, 12], 8.00e-7, 5.00e-6), # local: fp64=1.87e-7, fp32=1.34e-7 + (4, [14, 14], 8.00e-7, 5.00e-6), # local: fp64=2.17e-7, fp32=4.47e-7 + (5, [18, 18], 1.20e-6, 5.00e-6), # local: fp64=1.96e-7, fp32=5.96e-7 + (6, [20, 20], 1.20e-6, 5.00e-6), # local: fp64=9.70e-7, fp32=9.88e-7 + (7, [24, 24], 1.50e-6, 5.00e-6), # local: fp64=8.25e-7, fp32=9.24e-7 + ], + "lebedev": [ + (2, [7, 26], 1.00e-12, 5.00e-6), # local: fp64=4.05e-15, fp32=4.47e-8 + (3, [9, 38], 1.00e-12, 5.00e-6), # local: fp64=7.19e-15, fp32=7.45e-8 + (4, [13, 74], 1.00e-12, 5.00e-6), # local: fp64=1.57e-14, fp32=2.53e-7 + (5, [15, 86], 1.00e-12, 5.00e-6), # local: fp64=7.59e-15, fp32=2.12e-7 + (6, [19, 146], 1.00e-12, 5.00e-6), # local: fp64=2.21e-14, fp32=2.98e-7 + (7, [21, 170], 1.00e-12, 5.00e-6), # local: fp64=2.73e-14, fp32=6.85e-7 + ], + } + dtype_cases = [ + (torch.float64, 0), + (torch.float32, 1), + ] + n_batch = 3 + n_focus = 1 + channels = 2 + + for dtype, tolerance_index in dtype_cases: + for method, cases in cases_by_method.items(): + for lmax, expected_full_m_grid, *tolerances in cases: + with self.subTest( + method=method, + dtype=dtype, + lmax=lmax, + grid=expected_full_m_grid, + ): + self._assert_default_full_m_s2_activation_equivariance( + grid_method=method, + lmax=lmax, + expected_full_m_grid=expected_full_m_grid, + n_batch=n_batch, + n_focus=n_focus, + channels=channels, + dtype=dtype, + tolerance=tolerances[tolerance_index], + op_type="mlp", + ) + + def test_default_mmax_truncated_grid_counts_keep_s2_activation_z_equivariant( + self, + ) -> None: + """Default mmax-truncated S2 activation grids should keep z-equivariance.""" + # Each case is (lmax, mmax, truncated_grid, fp64_tol, fp32_tol). + # e3nn truncated_grid is [R_phi, R_theta] used by the m-major path. + # Lebedev truncated_grid is [precision, n_points]. + cases_by_method = { + "e3nn": { + 1: [ + (2, [6, 8], 2.80e-7, 5.00e-6), # local: fp64=2.36e-7, fp32=3.58e-7 + (3, [6, 12], 1.50e-7, 5.00e-6), # local: fp64=1.22e-7, fp32=5.96e-7 + (4, [6, 14], 1.33e-6, 5.00e-6), # local: fp64=1.12e-6, fp32=9.54e-7 + (5, [6, 18], 1.30e-7, 5.00e-6), # local: fp64=1.10e-7, fp32=1.43e-6 + (6, [6, 20], 9.00e-7, 5.00e-6), # local: fp64=7.64e-7, fp32=1.91e-6 + (7, [6, 24], 2.60e-7, 5.00e-6), # local: fp64=2.17e-7, fp32=1.91e-6 + ], + 2: [ + (2, [8, 8], 4.70e-7, 5.00e-6), # local: fp64=4.01e-7, fp32=8.34e-7 + (3, [8, 12], 7.00e-7, 5.00e-6), # local: fp64=5.99e-7, fp32=8.34e-7 + (4, [8, 14], 7.00e-7, 5.00e-6), # local: fp64=6.02e-7, fp32=1.67e-6 + (5, [8, 18], 1.40e-6, 5.00e-6), # local: fp64=1.19e-6, fp32=1.55e-6 + (6, [8, 20], 1.55e-6, 5.00e-6), # local: fp64=1.33e-6, fp32=2.15e-6 + (7, [8, 24], 1.65e-6, 5.00e-6), # local: fp64=1.41e-6, fp32=2.62e-6 + ], + }, + "lebedev": { + 1: [ + ( + 2, + [7, 26], + 1.00e-12, + 5.00e-6, + ), # local: fp64=2.31e-14, fp32=2.38e-7 + ( + 3, + [9, 38], + 1.00e-12, + 5.00e-6, + ), # local: fp64=3.55e-14, fp32=2.98e-7 + ( + 4, + [13, 74], + 1.00e-12, + 5.00e-6, + ), # local: fp64=1.04e-13, fp32=9.54e-7 + ( + 5, + [15, 86], + 1.00e-12, + 5.00e-6, + ), # local: fp64=9.34e-14, fp32=7.15e-7 + ( + 6, + [19, 146], + 1.00e-12, + 5.00e-6, + ), # local: fp64=8.56e-14, fp32=2.15e-6 + ( + 7, + [21, 170], + 1.00e-12, + 5.00e-6, + ), # local: fp64=2.08e-13, fp32=3.34e-6 + ], + 2: [ + ( + 2, + [7, 26], + 1.00e-12, + 5.00e-6, + ), # local: fp64=1.50e-14, fp32=2.38e-7 + ( + 3, + [9, 38], + 1.00e-12, + 5.00e-6, + ), # local: fp64=5.71e-14, fp32=3.58e-7 + ( + 4, + [13, 74], + 1.00e-12, + 5.00e-6, + ), # local: fp64=9.15e-14, fp32=5.96e-7 + ( + 5, + [15, 86], + 1.00e-12, + 5.00e-6, + ), # local: fp64=7.83e-14, fp32=4.77e-7 + ( + 6, + [19, 146], + 1.00e-12, + 5.00e-6, + ), # local: fp64=1.29e-13, fp32=9.54e-7 + ( + 7, + [21, 170], + 1.00e-12, + 5.00e-6, + ), # local: fp64=1.56e-13, fp32=1.43e-6 + ], + }, + } + dtype_cases = [ + (torch.float64, 0), + (torch.float32, 1), + ] + n_batch = 3 + n_focus = 2 + channels = 2 + + for dtype, tolerance_index in dtype_cases: + for method, cases_by_mmax in cases_by_method.items(): + for mmax, cases in cases_by_mmax.items(): + for lmax, expected_truncated_grid, *tolerances in cases: + with self.subTest( + method=method, + dtype=dtype, + lmax=lmax, + mmax=mmax, + grid=expected_truncated_grid, + ): + self._assert_default_mmax_truncated_grid_z_equivariance( + grid_method=method, + lmax=lmax, + mmax=mmax, + expected_truncated_grid=expected_truncated_grid, + n_batch=n_batch, + n_focus=n_focus, + channels=channels, + dtype=dtype, + tolerance=tolerances[tolerance_index], + ) + + def _assert_default_mmax_truncated_grid_z_equivariance( + self, + *, + grid_method: str, + lmax: int, + mmax: int, + expected_truncated_grid: list[int], + n_batch: int, + n_focus: int, + channels: int, + dtype: torch.dtype, + tolerance: float, + op_type: str = "glu", + ) -> None: + """Assert mmax-truncated S2 activation z-equivariance for one case.""" + torch.manual_seed(2234 + lmax + 100 * mmax) + truncated_grid = resolve_s2_grid_resolution( + lmax, + mmax, + method=grid_method, + ) + self.assertEqual(truncated_grid, expected_truncated_grid) + + activation = S2GridNet( + lmax=lmax, + mmax=mmax, + channels=channels, + dtype=dtype, + n_focus=n_focus, + mode="self", + op_type=op_type, + layout="nfdc", + grid_resolution_list=truncated_grid, + coefficient_layout="m_major", + grid_method=grid_method, + mlp_bias=False, + trainable=False, + seed=27 + lmax + 100 * mmax, + ).to(self.device) + self.assertEqual(activation.grid_resolution_list, expected_truncated_grid) + + coeff_index = build_m_major_index(lmax, mmax, device=self.device) + x = torch.randn( + n_batch, + n_focus, + int(coeff_index.numel()), + 2 * channels, + device=self.device, + dtype=dtype, + ) + gamma = torch.randn(n_batch, device=self.device, dtype=dtype) + quaternion = quaternion_z_rotation(gamma) + d_matrix, _ = WignerDCalculator(lmax=lmax, dtype=dtype).to(self.device)( + quaternion + ) + d_matrix_reduced = d_matrix.index_select(1, coeff_index).index_select( + 2, + coeff_index, + ) + + y_rotated_input = activation(_rotate_nfdc(x, d_matrix_reduced)) + y_then_rotated = _rotate_nfdc(activation(x), d_matrix_reduced) + max_error = _max_abs_equivariance_error( + y_rotated_input, + y_then_rotated, + ) + + self.assertLessEqual(max_error, tolerance) + + def test_polynomial_grid_mlp_mmax_truncated_s2_z_equivariance(self) -> None: + """S2 grid MLP should keep mmax-truncated z-equivariance.""" + # Each case is (lmax, mmax, truncated_grid, fp64_tol, fp32_tol). + cases_by_method = { + "e3nn": { + 1: [ + (2, [6, 8], 2.00e-7, 5.00e-6), # local: fp64=5.74e-8, fp32=1.19e-7 + (3, [6, 12], 2.00e-7, 5.00e-6), # local: fp64=2.16e-8, fp32=1.49e-7 + (4, [6, 14], 8.00e-7, 5.00e-6), # local: fp64=4.90e-7, fp32=2.38e-7 + (5, [6, 18], 2.00e-7, 5.00e-6), # local: fp64=3.29e-8, fp32=3.58e-7 + (6, [6, 20], 4.00e-7, 5.00e-6), # local: fp64=7.78e-8, fp32=4.17e-7 + (7, [6, 24], 4.00e-7, 5.00e-6), # local: fp64=1.14e-7, fp32=6.56e-7 + ], + 2: [ + (2, [8, 8], 2.00e-7, 5.00e-6), # local: fp64=7.34e-8, fp32=1.19e-7 + (3, [8, 12], 4.00e-7, 5.00e-6), # local: fp64=6.49e-8, fp32=6.56e-7 + (4, [8, 14], 8.00e-7, 5.00e-6), # local: fp64=1.33e-7, fp32=2.09e-7 + (5, [8, 18], 4.00e-7, 5.00e-6), # local: fp64=2.63e-7, fp32=2.40e-7 + (6, [8, 20], 8.00e-7, 5.00e-6), # local: fp64=3.10e-7, fp32=7.45e-7 + (7, [8, 24], 8.00e-7, 5.00e-6), # local: fp64=2.95e-7, fp32=2.38e-7 + ], + }, + "lebedev": { + 1: [ + ( + 2, + [7, 26], + 1.00e-12, + 5.00e-6, + ), # local: fp64=2.28e-15, fp32=7.45e-8 + ( + 3, + [9, 38], + 1.00e-12, + 5.00e-6, + ), # local: fp64=4.04e-15, fp32=1.79e-7 + ( + 4, + [13, 74], + 1.00e-12, + 5.00e-6, + ), # local: fp64=5.44e-14, fp32=2.38e-7 + ( + 5, + [15, 86], + 1.00e-12, + 5.00e-6, + ), # local: fp64=1.99e-14, fp32=2.38e-7 + ( + 6, + [19, 146], + 1.00e-12, + 5.00e-6, + ), # local: fp64=1.81e-14, fp32=9.54e-7 + ( + 7, + [21, 170], + 1.00e-12, + 5.00e-6, + ), # local: fp64=4.86e-14, fp32=3.87e-7 + ], + 2: [ + ( + 2, + [7, 26], + 1.00e-12, + 5.00e-6, + ), # local: fp64=2.84e-15, fp32=7.45e-8 + ( + 3, + [9, 38], + 1.00e-12, + 5.00e-6, + ), # local: fp64=5.33e-15, fp32=1.19e-7 + ( + 4, + [13, 74], + 1.00e-12, + 5.00e-6, + ), # local: fp64=7.45e-15, fp32=1.19e-7 + ( + 5, + [15, 86], + 1.00e-12, + 5.00e-6, + ), # local: fp64=1.68e-14, fp32=1.19e-7 + ( + 6, + [19, 146], + 1.00e-12, + 5.00e-6, + ), # local: fp64=2.62e-14, fp32=4.77e-7 + ( + 7, + [21, 170], + 1.00e-12, + 5.00e-6, + ), # local: fp64=1.98e-14, fp32=2.38e-7 + ], + }, + } + dtype_cases = [ + (torch.float64, 0), + (torch.float32, 1), + ] + n_batch = 3 + n_focus = 2 + channels = 2 + + for dtype, tolerance_index in dtype_cases: + for method, cases_by_mmax in cases_by_method.items(): + for mmax, cases in cases_by_mmax.items(): + for lmax, expected_truncated_grid, *tolerances in cases: + with self.subTest( + method=method, + dtype=dtype, + lmax=lmax, + mmax=mmax, + grid=expected_truncated_grid, + ): + self._assert_default_mmax_truncated_grid_z_equivariance( + grid_method=method, + lmax=lmax, + mmax=mmax, + expected_truncated_grid=expected_truncated_grid, + n_batch=n_batch, + n_focus=n_focus, + channels=channels, + dtype=dtype, + tolerance=tolerances[tolerance_index], + op_type="mlp", + ) + + +class TestSO3GridProjector(unittest.TestCase): + def setUp(self) -> None: + torch.manual_seed(2026) + self.device = torch.device("cpu") + + def test_roundtrip_preserves_legal_frame_coefficients(self) -> None: + # WignerDCalculator-based grids keep local macOS fp64 round-trip errors + # below 4.1e-13 for lmax=1..6 without a dual-basis correction. + for lmax in range(1, 7): + with self.subTest(lmax=lmax): + torch.manual_seed(8100 + lmax) + projector = SO3GridProjector( + lmax=lmax, + kmax=1, + dtype=torch.float64, + ).to(self.device) + x = torch.randn( + 2, + projector.coeff_dim, + 2, + dtype=torch.float64, + device=self.device, + ) + mask = _legal_so3_frame_mask(projector) + x[:, ~mask, :] = 0.0 + y = projector.from_grid(projector.to_grid(x)) + torch.testing.assert_close( + y[:, mask, :], + x[:, mask, :], + atol=1e-12, + rtol=1e-12, + ) + self.assertLess(float(y[:, ~mask, :].abs().max()), 1e-14) + + def test_projection_matrices_match_direct_wigner_construction(self) -> None: + projector = SO3GridProjector( + lmax=2, + kmax=1, + dtype=torch.float64, + lebedev_precision=17, + ) + points, weights = load_lebedev_rule( + projector.lebedev_precision, + dtype=torch.float64, + device=torch.device("cpu"), + ) + gamma = torch.arange( + projector.n_gamma, + dtype=torch.float64, + device=points.device, + ) * (2.0 * torch.pi / projector.n_gamma) + edge_quaternion = build_edge_quaternion(points, eps=1e-14) + edge_quaternion = edge_quaternion.repeat_interleave(projector.n_gamma, dim=0) + gamma_quaternion = quaternion_z_rotation(gamma).repeat(points.shape[0], 1) + grid_quaternion = quaternion_multiply(gamma_quaternion, edge_quaternion) + wigner_grid, _ = WignerDCalculator( + lmax=projector.lmax, + dtype=torch.float64, + ).to(grid_quaternion.device)(grid_quaternion) + wigner_grid = wigner_grid.transpose(-1, -2).contiguous() + haar_weight = weights.repeat_interleave(projector.n_gamma) / projector.n_gamma + to_grid_ref = torch.zeros_like(projector.to_grid_mat) + from_grid_ref = torch.zeros_like(projector.from_grid_mat) + for degree in range(projector.lmax + 1): + for m_order in range(-degree, degree + 1): + packed_idx = degree * degree + degree + m_order + for frame_pos, frame_order in enumerate(projector.frame_set): + flat_idx = packed_idx * projector.n_frames + frame_pos + if abs(frame_order) > degree: + continue + row = degree * degree + degree + m_order + col = degree * degree + degree + frame_order + values = wigner_grid[:, row, col] + to_grid_ref[:, flat_idx] = values + from_grid_ref[flat_idx] = (2 * degree + 1) * haar_weight * values + torch.testing.assert_close(projector.to_grid_mat, to_grid_ref) + torch.testing.assert_close(projector.from_grid_mat, from_grid_ref) + + def test_k_zero_slice_matches_wigner_zonal_convention(self) -> None: + lmax = 6 + projector = SO3GridProjector(lmax=lmax, kmax=0, dtype=torch.float64) + points, _ = load_lebedev_rule( + projector.lebedev_precision, + dtype=torch.float64, + device=torch.device("cpu"), + ) + edge_quaternion = build_edge_quaternion(points, eps=1e-14) + zonal = ( + WignerDCalculator(lmax=lmax, dtype=torch.float64) + .to(edge_quaternion.device) + .forward_zonal(edge_quaternion, lmin=1) + ) + torch.testing.assert_close( + projector.to_grid_mat[:, 0], torch.ones_like(points[:, 0]) + ) + torch.testing.assert_close( + projector.to_grid_mat[:, 1:], zonal, atol=1e-14, rtol=1e-14 + ) + + def test_quadratic_gamma_rule_resolves_kmax_two_products(self) -> None: + projector = SO3GridProjector(lmax=2, kmax=2, dtype=torch.float64) + self.assertEqual(projector.n_gamma, 7) + + +class TestSO3GridNet(unittest.TestCase): + def setUp(self) -> None: + torch.manual_seed(2027) + self.device = torch.device("cpu") + + def test_swiglu_so3_grid_net_equivariance(self) -> None: + # Each case is (lmax, fp64_tol, fp32_tol). Observed reference errors: + # fp64: [7.58e-14, 1.50e-13, 1.95e-13, 4.70e-13, 2.84e-13, 7.04e-13] + # fp32: [2.99e-7, 2.15e-6, 1.55e-6, 2.31e-6, 4.30e-6, 5.13e-6] + cases = [ + (1, 2e-13, 6e-7), + (2, 3e-13, 3e-6), + (3, 4e-13, 3e-6), + (4, 8e-13, 4e-6), + (5, 6e-13, 6e-6), + (6, 1e-12, 7e-6), + ] + channels = 2 + dtype_cases = [ + (torch.float64, 1), + (torch.float32, 2), + ] + for dtype, tolerance_index in dtype_cases: + for case in cases: + lmax = case[0] + tolerance = case[tolerance_index] + with self.subTest(dtype=dtype, lmax=lmax): + torch.manual_seed( + 7100 + lmax + (0 if dtype is torch.float64 else 100) + ) + net = SO3GridNet( + lmax=lmax, + kmax=1, + channels=channels, + n_focus=1, + mode="self", + op_type="glu", + dtype=dtype, + layout="ndfc", + trainable=False, + ).to(self.device) + x = torch.randn( + 2, + (lmax + 1) ** 2, + 1, + net.query_channels, + dtype=dtype, + device=self.device, + ) + quat = _random_quaternion(2, dtype=dtype, device=self.device) + d_matrix, _ = WignerDCalculator(lmax=lmax, dtype=dtype).to( + self.device + )(quat) + y_rotated_input = net(_rotate_ndfc(x, d_matrix)) + y_then_rotated = _rotate_ndfc(net(x), d_matrix) + torch.testing.assert_close( + y_rotated_input, + y_then_rotated, + atol=tolerance, + rtol=tolerance, + ) + + def test_polynomial_grid_mlp_so3_grid_net_equivariance(self) -> None: + for lmax in range(1, 5): + with self.subTest(lmax=lmax): + torch.manual_seed(8200 + lmax) + net = SO3GridNet( + lmax=lmax, + kmax=1, + channels=2, + n_focus=1, + mode="self", + op_type="mlp", + dtype=torch.float64, + layout="ndfc", + trainable=False, + ).to(self.device) + x = torch.randn( + 2, + (lmax + 1) ** 2, + 1, + net.query_channels, + dtype=torch.float64, + device=self.device, + ) + quat = _random_quaternion(2, dtype=torch.float64, device=self.device) + d_matrix, _ = WignerDCalculator(lmax=lmax, dtype=torch.float64).to( + self.device + )(quat) + y_rotated_input = net(_rotate_ndfc(x, d_matrix)) + y_then_rotated = _rotate_ndfc(net(x), d_matrix) + torch.testing.assert_close( + y_rotated_input, + y_then_rotated, + atol=1e-12, + rtol=1e-12, + ) + + def test_scalar_router_grid_branch_so3_grid_net_equivariance(self) -> None: + for lmax in range(1, 5): + with self.subTest(lmax=lmax): + torch.manual_seed(8300 + lmax) + net = SO3GridNet( + lmax=lmax, + kmax=1, + channels=2, + n_focus=1, + mode="self", + op_type="branch", + dtype=torch.float64, + layout="ndfc", + grid_branches=2, + trainable=False, + ).to(self.device) + x = torch.randn( + 2, + (lmax + 1) ** 2, + 1, + net.query_channels, + dtype=torch.float64, + device=self.device, + ) + quat = _random_quaternion(2, dtype=torch.float64, device=self.device) + d_matrix, _ = WignerDCalculator(lmax=lmax, dtype=torch.float64).to( + self.device + )(quat) + y_rotated_input = net(_rotate_ndfc(x, d_matrix)) + y_then_rotated = _rotate_ndfc(net(x), d_matrix) + torch.testing.assert_close( + y_rotated_input, + y_then_rotated, + atol=1e-12, + rtol=1e-12, + ) + + def test_kmax_two_quadratic_grid_ops_are_equivariant(self) -> None: + for op_type in ["glu", "mlp", "branch"]: + with self.subTest(op_type=op_type): + torch.manual_seed(8400) + net = SO3GridNet( + lmax=2, + kmax=2, + channels=2, + n_focus=1, + mode="self", + op_type=op_type, + dtype=torch.float64, + layout="ndfc", + grid_branches=2, + trainable=False, + ).to(self.device) + x = torch.randn( + 2, + 9, + 1, + net.query_channels, + dtype=torch.float64, + device=self.device, + ) + quat = _random_quaternion(2, dtype=torch.float64, device=self.device) + d_matrix, _ = WignerDCalculator(lmax=2, dtype=torch.float64).to( + self.device + )(quat) + y_rotated_input = net(_rotate_ndfc(x, d_matrix)) + y_then_rotated = _rotate_ndfc(net(x), d_matrix) + torch.testing.assert_close( + y_rotated_input, + y_then_rotated, + atol=1e-12, + rtol=1e-12, + ) + + +class TestSO3CounterExample(unittest.TestCase): + def setUp(self) -> None: + torch.manual_seed(2028) + self.device = torch.device("cpu") + self.l1_to_cartesian = torch.tensor( + [[0.0, -1.0, 0.0], [0.0, 0.0, -1.0], [1.0, 0.0, 0.0]], + device=self.device, + dtype=torch.float64, + ) + + def test_so3_features_span_odd_targets_but_s2_features_do_not(self) -> None: + n_sample = 2048 + channels = 48 + x = torch.randn(n_sample, 3, 3, dtype=torch.float64, device=self.device) + target, target_det = self._odd_targets_from_l1_coefficients(x) + + s2_projector = S2GridProjector( + lmax=2, + dtype=torch.float64, + grid_method="lebedev", + grid_resolution_list=[17, 110], + ) + so3_projector = SO3GridProjector( + lmax=2, + kmax=1, + dtype=torch.float64, + lebedev_precision=17, + ) + s2_features_by_m = self._s2_quadratic_features(x[:, :2], s2_projector, channels) + so3_features_by_m = self._so3_quadratic_features( + x[:, :2], + so3_projector, + channels, + ) + s2_residual = self._best_linear_residual( + s2_features_by_m.reshape(n_sample * 3, channels), + target, + ) + so3_residual = self._best_linear_residual( + so3_features_by_m.reshape(n_sample * 3, so3_projector.n_frames * channels), + target, + ) + s2_det_features = self._couple_l1_features_to_scalar(s2_features_by_m, x[:, 2]) + so3_det_features = self._couple_l1_features_to_scalar( + so3_features_by_m, + x[:, 2], + ) + s2_det_residual = self._best_linear_residual(s2_det_features, target_det) + so3_det_residual = self._best_linear_residual(so3_det_features, target_det) + self.assertGreater(s2_residual, 0.9) + self.assertLess(so3_residual, 0.35) + self.assertGreater(s2_det_residual, 0.9) + self.assertLess(so3_det_residual, 0.35) + + def _s2_quadratic_features( + self, + x: torch.Tensor, + projector: S2GridProjector, + channels: int, + ) -> torch.Tensor: + weight = torch.randn(2, 2 * channels, dtype=x.dtype, device=x.device) + coeff = x.new_zeros(x.shape[0], projector.coeff_dim, 2 * channels) + coeff[:, 1:4, :] = torch.einsum("bmi,ic->bmc", x.transpose(1, 2), weight) + grid = projector.to_grid(coeff) + grid_a, grid_b = grid.chunk(2, dim=-1) + out = projector.from_grid(grid_a * grid_b) + return out[:, 1:4, :] + + def _so3_quadratic_features( + self, + x: torch.Tensor, + projector: SO3GridProjector, + channels: int, + ) -> torch.Tensor: + n_frames = projector.n_frames + weight = torch.randn( + 2, + 2 * n_frames * channels, + dtype=x.dtype, + device=x.device, + ) + coeff = x.new_zeros(x.shape[0], projector.coeff_dim, 2 * channels) + mixed = torch.einsum("bmi,ic->bmc", x.transpose(1, 2), weight) + for local_m in range(3): + packed_idx = 1 + local_m + start = packed_idx * n_frames + stop = start + n_frames + coeff[:, start:stop, :] = mixed[:, local_m, :].reshape( + x.shape[0], + n_frames, + 2 * channels, + ) + grid = projector.to_grid(coeff) + grid_a, grid_b = grid.chunk(2, dim=-1) + out = projector.from_grid(grid_a * grid_b) + rows = [] + for local_m in range(3): + packed_idx = 1 + local_m + rows.extend(range(packed_idx * n_frames, (packed_idx + 1) * n_frames)) + return out[:, rows, :].reshape(x.shape[0], 3, n_frames * channels) + + def _couple_l1_features_to_scalar( + self, + features: torch.Tensor, + vector: torch.Tensor, + ) -> torch.Tensor: + features_cartesian = torch.einsum( + "bmp,mi->bip", + features, + self.l1_to_cartesian.to(dtype=features.dtype, device=features.device), + ) + vector_cartesian = self._l1_coefficients_to_cartesian(vector) + return torch.einsum("bip,bi->bp", features_cartesian, vector_cartesian) + + def _odd_targets_from_l1_coefficients( + self, + coeff: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + first = self._l1_coefficients_to_cartesian(coeff[:, 0]) + second = self._l1_coefficients_to_cartesian(coeff[:, 1]) + third = self._l1_coefficients_to_cartesian(coeff[:, 2]) + cross_cartesian = torch.linalg.cross(first, second, dim=-1) + cross_coeff = self._cartesian_to_l1_coefficients(cross_cartesian) + determinant = torch.sum(cross_cartesian * third, dim=-1) + return cross_coeff, determinant + + def _l1_coefficients_to_cartesian(self, coeff: torch.Tensor) -> torch.Tensor: + return coeff @ self.l1_to_cartesian.to(dtype=coeff.dtype, device=coeff.device) + + def _cartesian_to_l1_coefficients(self, vector: torch.Tensor) -> torch.Tensor: + return ( + vector @ self.l1_to_cartesian.to(dtype=vector.dtype, device=vector.device).T + ) + + def _best_linear_residual( + self, + features: torch.Tensor, + target: torch.Tensor, + ) -> float: + y = target.reshape(-1, 1) + solution = torch.linalg.lstsq(features, y).solution + residual = features @ solution - y + return float(residual.norm() / y.norm().clamp_min(1e-30)) + + +if __name__ == "__main__": + unittest.main() diff --git a/source/tests/pt/model/test_descriptor_sezm_s2_equivariance.py b/source/tests/pt/model/test_descriptor_sezm_s2_equivariance.py deleted file mode 100644 index b6a6acbe4e..0000000000 --- a/source/tests/pt/model/test_descriptor_sezm_s2_equivariance.py +++ /dev/null @@ -1,384 +0,0 @@ -# SPDX-License-Identifier: LGPL-3.0-or-later -import unittest - -import torch - -from deepmd.pt.model.descriptor.sezm_nn import ( - S2GridProjector, - SwiGLUS2Activation, - WignerDCalculator, - build_m_major_index, - quaternion_z_rotation, - resolve_s2_grid_resolution, -) - - -def _random_quaternion( - n_batch: int, - *, - device: torch.device, - dtype: torch.dtype, -) -> torch.Tensor: - """Sample normalized quaternions in ``(w, x, y, z)`` order.""" - q = torch.randn(n_batch, 4, device=device, dtype=dtype) - return q / torch.sqrt(torch.sum(q * q, dim=-1, keepdim=True)) - - -def _rotate_ndfc(x: torch.Tensor, d_matrix: torch.Tensor) -> torch.Tensor: - """Rotate coefficient-layout tensors with shape ``(N, D, F, C)``.""" - return torch.einsum("nij,njfc->nifc", d_matrix, x) - - -def _rotate_nfdc(x: torch.Tensor, d_matrix: torch.Tensor) -> torch.Tensor: - """Rotate coefficient-layout tensors with shape ``(N, F, D, C)``.""" - return torch.einsum("nij,nfjc->nfic", d_matrix, x) - - -def _max_abs_equivariance_error(lhs: torch.Tensor, rhs: torch.Tensor) -> float: - """Compute the maximum absolute equivariance error.""" - return float(torch.max(torch.abs(lhs - rhs)).item()) - - -class TestS2GridProjector(unittest.TestCase): - """Test S2 projection invariants.""" - - def setUp(self) -> None: - self.device = torch.device("cpu") - torch.manual_seed(0) - - def test_lebedev_roundtrip_preserves_bandlimited_coefficients(self) -> None: - """Lebedev quadrature should reconstruct coefficients up to lmax.""" - projector = S2GridProjector( - lmax=3, - dtype=torch.float64, - grid_resolution_list=None, - coefficient_layout="packed", - grid_method="lebedev", - ).to(self.device) - x = torch.randn( - 5, projector.coeff_dim, 3, device=self.device, dtype=torch.float64 - ) - y = projector.from_grid(projector.to_grid(x)) - torch.testing.assert_close(y, x, atol=1e-12, rtol=1e-12) - - -class TestSwiGLUS2Equivariance(unittest.TestCase): - """Test default-grid equivariance of full-m and truncated SwiGLU-S2 activations.""" - - def setUp(self) -> None: - self.device = torch.device("cpu") - torch.manual_seed(0) - - def test_default_full_m_grid_counts_keep_s2_activation_equivariant(self) -> None: - """Default full-m S2 activation grids should keep SO(3) equivariance.""" - # Each case is (lmax, full_m_grid, fp64_tol, fp32_tol). - # e3nn full_m_grid is [R_phi, R_theta] after the square-grid lift. - # Lebedev full_m_grid is [precision, n_points]. - cases_by_method = { - "e3nn": [ - (2, [8, 8], 4.20e-7, 5.00e-6), # local: fp64=3.62e-7, fp32=4.77e-7 - (3, [12, 12], 8.10e-7, 5.00e-6), # local: fp64=7.04e-7, fp32=6.86e-7 - (4, [14, 14], 9.20e-7, 5.00e-6), # local: fp64=7.97e-7, fp32=1.55e-6 - (5, [18, 18], 1.70e-6, 5.00e-6), # local: fp64=1.48e-6, fp32=1.49e-6 - (6, [20, 20], 4.80e-6, 5.00e-6), # local: fp64=4.14e-6, fp32=2.27e-6 - (7, [24, 24], 3.70e-6, 5.00e-6), # local: fp64=3.19e-6, fp32=2.03e-6 - ], - "lebedev": [ - (2, [7, 26], 1.00e-12, 5.00e-6), # local: fp64=2.31e-14, fp32=2.38e-7 - (3, [9, 38], 1.00e-12, 5.00e-6), # local: fp64=3.58e-14, fp32=3.58e-7 - (4, [13, 74], 1.00e-12, 5.00e-6), # local: fp64=5.82e-14, fp32=6.56e-7 - (5, [15, 86], 1.00e-12, 5.00e-6), # local: fp64=3.22e-14, fp32=6.56e-7 - (6, [19, 146], 1.00e-12, 5.00e-6), # local: fp64=7.99e-14, fp32=8.35e-7 - (7, [21, 170], 1.00e-12, 5.00e-6), # local: fp64=6.86e-14, fp32=8.79e-7 - ], - } - dtype_cases = [ - (torch.float64, 0), - (torch.float32, 1), - ] - n_batch = 3 - n_focus = 1 - channels = 2 - - for dtype, tolerance_index in dtype_cases: - for method, cases in cases_by_method.items(): - for lmax, expected_full_m_grid, *tolerances in cases: - with self.subTest( - method=method, - dtype=dtype, - lmax=lmax, - grid=expected_full_m_grid, - ): - self._assert_default_full_m_s2_activation_equivariance( - grid_method=method, - lmax=lmax, - expected_full_m_grid=expected_full_m_grid, - n_batch=n_batch, - n_focus=n_focus, - channels=channels, - dtype=dtype, - tolerance=tolerances[tolerance_index], - ) - - def _assert_default_full_m_s2_activation_equivariance( - self, - *, - grid_method: str, - lmax: int, - expected_full_m_grid: list[int], - n_batch: int, - n_focus: int, - channels: int, - dtype: torch.dtype, - tolerance: float, - ) -> None: - """Assert full-m S2 activation equivariance for one method/dtype/lmax case.""" - torch.manual_seed(1234 + lmax) - default_grid = resolve_s2_grid_resolution( - lmax, - lmax, - method=grid_method, - ) - full_m_grid = ( - [max(default_grid), max(default_grid)] - if grid_method == "e3nn" - else default_grid - ) - self.assertEqual(full_m_grid, expected_full_m_grid) - - activation = SwiGLUS2Activation( - lmax=lmax, - channels=channels, - dtype=dtype, - n_focus=n_focus, - layout="ndfc", - grid_resolution_list=full_m_grid, - coefficient_layout="packed", - grid_method=grid_method, - mlp_bias=False, - trainable=False, - seed=17 + lmax, - ).to(self.device) - self.assertEqual(activation.grid_resolution_list, expected_full_m_grid) - - x = torch.randn( - n_batch, - (lmax + 1) ** 2, - n_focus, - 2 * channels, - device=self.device, - dtype=dtype, - ) - quat = _random_quaternion(n_batch, device=self.device, dtype=dtype) - d_matrix, _ = WignerDCalculator(lmax=lmax, dtype=dtype).to(self.device)(quat) - - y_rotated_input = activation(_rotate_ndfc(x, d_matrix)) - y_then_rotated = _rotate_ndfc(activation(x), d_matrix) - max_error = _max_abs_equivariance_error( - y_rotated_input, - y_then_rotated, - ) - - self.assertLessEqual(max_error, tolerance) - - def test_default_mmax_truncated_grid_counts_keep_s2_activation_z_equivariant( - self, - ) -> None: - """Default mmax-truncated S2 activation grids should keep z-equivariance.""" - # Each case is (lmax, mmax, truncated_grid, fp64_tol, fp32_tol). - # e3nn truncated_grid is [R_phi, R_theta] used by the m-major path. - # Lebedev truncated_grid is [precision, n_points]. - cases_by_method = { - "e3nn": { - 1: [ - (2, [6, 8], 2.80e-7, 5.00e-6), # local: fp64=2.36e-7, fp32=3.58e-7 - (3, [6, 12], 1.50e-7, 5.00e-6), # local: fp64=1.22e-7, fp32=5.96e-7 - (4, [6, 14], 1.33e-6, 5.00e-6), # local: fp64=1.12e-6, fp32=9.54e-7 - (5, [6, 18], 1.30e-7, 5.00e-6), # local: fp64=1.10e-7, fp32=1.43e-6 - (6, [6, 20], 9.00e-7, 5.00e-6), # local: fp64=7.64e-7, fp32=1.91e-6 - (7, [6, 24], 2.60e-7, 5.00e-6), # local: fp64=2.17e-7, fp32=1.91e-6 - ], - 2: [ - (2, [8, 8], 4.70e-7, 5.00e-6), # local: fp64=4.01e-7, fp32=8.34e-7 - (3, [8, 12], 7.00e-7, 5.00e-6), # local: fp64=5.99e-7, fp32=8.34e-7 - (4, [8, 14], 7.00e-7, 5.00e-6), # local: fp64=6.02e-7, fp32=1.67e-6 - (5, [8, 18], 1.40e-6, 5.00e-6), # local: fp64=1.19e-6, fp32=1.55e-6 - (6, [8, 20], 1.55e-6, 5.00e-6), # local: fp64=1.33e-6, fp32=2.15e-6 - (7, [8, 24], 1.65e-6, 5.00e-6), # local: fp64=1.41e-6, fp32=2.62e-6 - ], - }, - "lebedev": { - 1: [ - ( - 2, - [7, 26], - 1.00e-12, - 5.00e-6, - ), # local: fp64=2.31e-14, fp32=2.38e-7 - ( - 3, - [9, 38], - 1.00e-12, - 5.00e-6, - ), # local: fp64=3.55e-14, fp32=2.98e-7 - ( - 4, - [13, 74], - 1.00e-12, - 5.00e-6, - ), # local: fp64=1.04e-13, fp32=9.54e-7 - ( - 5, - [15, 86], - 1.00e-12, - 5.00e-6, - ), # local: fp64=9.34e-14, fp32=7.15e-7 - ( - 6, - [19, 146], - 1.00e-12, - 5.00e-6, - ), # local: fp64=8.56e-14, fp32=2.15e-6 - ( - 7, - [21, 170], - 1.00e-12, - 5.00e-6, - ), # local: fp64=2.08e-13, fp32=3.34e-6 - ], - 2: [ - ( - 2, - [7, 26], - 1.00e-12, - 5.00e-6, - ), # local: fp64=1.50e-14, fp32=2.38e-7 - ( - 3, - [9, 38], - 1.00e-12, - 5.00e-6, - ), # local: fp64=5.71e-14, fp32=3.58e-7 - ( - 4, - [13, 74], - 1.00e-12, - 5.00e-6, - ), # local: fp64=9.15e-14, fp32=5.96e-7 - ( - 5, - [15, 86], - 1.00e-12, - 5.00e-6, - ), # local: fp64=7.83e-14, fp32=4.77e-7 - ( - 6, - [19, 146], - 1.00e-12, - 5.00e-6, - ), # local: fp64=1.29e-13, fp32=9.54e-7 - ( - 7, - [21, 170], - 1.00e-12, - 5.00e-6, - ), # local: fp64=1.56e-13, fp32=1.43e-6 - ], - }, - } - dtype_cases = [ - (torch.float64, 0), - (torch.float32, 1), - ] - n_batch = 3 - n_focus = 2 - channels = 2 - - for dtype, tolerance_index in dtype_cases: - for method, cases_by_mmax in cases_by_method.items(): - for mmax, cases in cases_by_mmax.items(): - for lmax, expected_truncated_grid, *tolerances in cases: - with self.subTest( - method=method, - dtype=dtype, - lmax=lmax, - mmax=mmax, - grid=expected_truncated_grid, - ): - self._assert_default_mmax_truncated_grid_z_equivariance( - grid_method=method, - lmax=lmax, - mmax=mmax, - expected_truncated_grid=expected_truncated_grid, - n_batch=n_batch, - n_focus=n_focus, - channels=channels, - dtype=dtype, - tolerance=tolerances[tolerance_index], - ) - - def _assert_default_mmax_truncated_grid_z_equivariance( - self, - *, - grid_method: str, - lmax: int, - mmax: int, - expected_truncated_grid: list[int], - n_batch: int, - n_focus: int, - channels: int, - dtype: torch.dtype, - tolerance: float, - ) -> None: - """Assert mmax-truncated S2 activation z-equivariance for one case.""" - torch.manual_seed(2234 + lmax + 100 * mmax) - truncated_grid = resolve_s2_grid_resolution( - lmax, - mmax, - method=grid_method, - ) - self.assertEqual(truncated_grid, expected_truncated_grid) - - activation = SwiGLUS2Activation( - lmax=lmax, - mmax=mmax, - channels=channels, - dtype=dtype, - n_focus=n_focus, - layout="nfdc", - grid_resolution_list=truncated_grid, - coefficient_layout="m_major", - grid_method=grid_method, - mlp_bias=False, - trainable=False, - seed=27 + lmax + 100 * mmax, - ).to(self.device) - self.assertEqual(activation.grid_resolution_list, expected_truncated_grid) - - coeff_index = build_m_major_index(lmax, mmax, device=self.device) - x = torch.randn( - n_batch, - n_focus, - int(coeff_index.numel()), - 2 * channels, - device=self.device, - dtype=dtype, - ) - gamma = torch.randn(n_batch, device=self.device, dtype=dtype) - quaternion = quaternion_z_rotation(gamma) - d_matrix, _ = WignerDCalculator(lmax=lmax, dtype=dtype).to(self.device)( - quaternion - ) - d_matrix_reduced = d_matrix.index_select(1, coeff_index).index_select( - 2, - coeff_index, - ) - - y_rotated_input = activation(_rotate_nfdc(x, d_matrix_reduced)) - y_then_rotated = _rotate_nfdc(activation(x), d_matrix_reduced) - max_error = _max_abs_equivariance_error( - y_rotated_input, - y_then_rotated, - ) - - self.assertLessEqual(max_error, tolerance) diff --git a/source/tests/pt/model/test_sezm_model.py b/source/tests/pt/model/test_sezm_model.py index 0e13669f63..15175a5adc 100644 --- a/source/tests/pt/model/test_sezm_model.py +++ b/source/tests/pt/model/test_sezm_model.py @@ -704,7 +704,7 @@ def test_forward_backward_double_backward_matches_compile(self) -> None: # Inductor Triton kernels use different reduction order vs eager, # so float32 gradients can differ by ~1e-3 on GPU. grad_atol = 1.0e-5 if self.device == torch.device("cpu") else 2.0e-3 - grad_rtol = 1.0e-5 if self.device == torch.device("cpu") else 1.0e-4 + grad_rtol = 1.0e-5 if self.device == torch.device("cpu") else 3.0e-3 self.assertEqual(set(grads_dyn.keys()), set(grads_cmp.keys())) for name in grads_dyn.keys(): _assert_close_with_strict_warning( @@ -883,11 +883,13 @@ def _build_wrapper(use_compile: bool) -> ModelWrapper: m_cmp.train() out_e = m_eager(coord, atype, box=box) out_c = m_cmp(coord, atype, box=box) + energy_atol = 1.0e-6 if self.device == torch.device("cpu") else 5.0e-6 + energy_rtol = 1.0e-6 if self.device == torch.device("cpu") else 5.0e-4 _assert_close_with_strict_warning( out_e["energy"], out_c["energy"], - atol=1.0e-6, - rtol=1.0e-6, + atol=energy_atol, + rtol=energy_rtol, msg=f"multitask energy mismatch at {branch}", ) _assert_close_with_strict_warning( @@ -1911,11 +1913,13 @@ def test_forward_and_backward_match_eager(self) -> None: # === Forward === out_eager = model_eager(coord, atype, box=box) out_compile = model_compile(coord, atype, box=box) + energy_atol = 1.0e-6 if self.device == torch.device("cpu") else 1.0e-4 + energy_rtol = 1.0e-6 if self.device == torch.device("cpu") else 1.0e-4 _assert_close_with_strict_warning( out_eager["energy"], out_compile["energy"], - atol=1.0e-6, - rtol=1.0e-6, + atol=energy_atol, + rtol=energy_rtol, msg="LoRA energy mismatch", ) _assert_close_with_strict_warning( @@ -1932,7 +1936,7 @@ def test_forward_and_backward_match_eager(self) -> None: out_eager["energy"].sum().backward() out_compile["energy"].sum().backward() grad_atol = 1.0e-5 if self.device == torch.device("cpu") else 2.0e-3 - grad_rtol = 1.0e-5 if self.device == torch.device("cpu") else 1.0e-4 + grad_rtol = 1.0e-5 if self.device == torch.device("cpu") else 3.0e-3 force_grad_atol = 1.0e-2 force_grad_rtol = 1.0e-4 grads_eager = { From b13583af99fe46ff66349501fbcf76601304b542 Mon Sep 17 00:00:00 2001 From: OutisLi Date: Sun, 7 Jun 2026 20:06:43 +0800 Subject: [PATCH 03/18] refactor(dpa4): triton rotation --- deepmd/pt/model/descriptor/sezm.py | 9 - deepmd/pt/model/descriptor/sezm_nn/block.py | 6 - .../model/descriptor/sezm_nn/cute/__init__.py | 36 + .../descriptor/sezm_nn/cute/so2_rotation.py | 916 +++++++++ .../pt/model/descriptor/sezm_nn/edge_cache.py | 65 +- deepmd/pt/model/descriptor/sezm_nn/so2.py | 75 +- .../descriptor/sezm_nn/triton/__init__.py | 33 +- .../descriptor/sezm_nn/triton/autograd.py | 837 -------- .../descriptor/sezm_nn/triton/constants.py | 46 - .../descriptor/sezm_nn/triton/custom_ops.py | 861 --------- .../descriptor/sezm_nn/triton/dispatch.py | 134 -- .../triton/kernels_edge_geometry_rbf.py | 550 ------ .../sezm_nn/triton/kernels_generic.py | 555 ------ .../sezm_nn/triton/kernels_small.py | 1317 ------------- .../descriptor/sezm_nn/triton/so2_rotation.py | 1715 +++++++++++++++++ deepmd/pt/model/model/sezm_model.py | 1 - doc/model/dpa4.md | 15 + .../pt/model/test_descriptor_sezm_triton.py | 1113 ++--------- source/tests/pt/model/test_sezm_model.py | 1 - 19 files changed, 2915 insertions(+), 5370 deletions(-) create mode 100644 deepmd/pt/model/descriptor/sezm_nn/cute/__init__.py create mode 100644 deepmd/pt/model/descriptor/sezm_nn/cute/so2_rotation.py delete mode 100644 deepmd/pt/model/descriptor/sezm_nn/triton/autograd.py delete mode 100644 deepmd/pt/model/descriptor/sezm_nn/triton/constants.py delete mode 100644 deepmd/pt/model/descriptor/sezm_nn/triton/custom_ops.py delete mode 100644 deepmd/pt/model/descriptor/sezm_nn/triton/dispatch.py delete mode 100644 deepmd/pt/model/descriptor/sezm_nn/triton/kernels_edge_geometry_rbf.py delete mode 100644 deepmd/pt/model/descriptor/sezm_nn/triton/kernels_generic.py delete mode 100644 deepmd/pt/model/descriptor/sezm_nn/triton/kernels_small.py create mode 100644 deepmd/pt/model/descriptor/sezm_nn/triton/so2_rotation.py diff --git a/deepmd/pt/model/descriptor/sezm.py b/deepmd/pt/model/descriptor/sezm.py index 257985a805..b20e6666a0 100644 --- a/deepmd/pt/model/descriptor/sezm.py +++ b/deepmd/pt/model/descriptor/sezm.py @@ -33,7 +33,6 @@ ) import math -import os from contextlib import ( contextmanager, ) @@ -555,12 +554,6 @@ def __init__( self.layer_scale = bool(layer_scale) self.use_amp = bool(use_amp) # and self.training self.trainable = bool(trainable) - self.use_triton = os.environ.get("DP_TRITON", "0").lower() in ( - "1", - "true", - "yes", - "on", - ) self.seed = seed self.random_gamma = bool(random_gamma) self.add_chg_spin_ebd = bool(add_chg_spin_ebd) @@ -899,7 +892,6 @@ def __init__( ffn_activation_function=self.ffn_activation_function, ffn_glu_activation=self.ffn_glu_activation, mlp_bias=self.mlp_bias, - use_triton=self.use_triton, eps=self.eps, dtype=self.dtype, seed=child_seed(seed_blocks, block_idx), @@ -1128,7 +1120,6 @@ def forward( # the model is roll-equivariant, so inference fixes gamma. random_gamma=self.random_gamma and self.training, wigner_calc=self.wigner_calc, - use_geometry_rbf_triton=(self.use_triton and not self.training), ) ebed_dim_0 = self.node_ebed_dims[0] # (node_lmax+1)^2 diff --git a/deepmd/pt/model/descriptor/sezm_nn/block.py b/deepmd/pt/model/descriptor/sezm_nn/block.py index d1ab12ad99..bc19d53cc6 100644 --- a/deepmd/pt/model/descriptor/sezm_nn/block.py +++ b/deepmd/pt/model/descriptor/sezm_nn/block.py @@ -214,9 +214,6 @@ class SeZMInteractionBlock(nn.Module): - SO3Linear: l=0 bias - SO2Linear: l=0 bias - GatedActivation: gate linear bias - use_triton - If True, opt into fused Triton SO(2) rotation kernels inside - ``SO2Convolution`` when the runtime supports them. eps Small epsilon for numerical stability. dtype @@ -275,7 +272,6 @@ def __init__( ffn_activation_function: str, ffn_glu_activation: bool = True, mlp_bias: bool = False, - use_triton: bool = False, eps: float = 1e-7, dtype: torch.dtype, seed: int | list[int] | None, @@ -370,7 +366,6 @@ def __init__( self.ffn_activation_function = str(ffn_activation_function) self.ffn_glu_activation = bool(ffn_glu_activation) self.mlp_bias = bool(mlp_bias) - self.use_triton = bool(use_triton) self.eps = float(eps) self.dtype = dtype self.device = env.DEVICE @@ -436,7 +431,6 @@ def __init__( lebedev_quadrature=self.so2_lebedev_quadrature, activation_function=self.so2_activation_function, mlp_bias=self.mlp_bias, - use_triton=self.use_triton, eps=self.eps, dtype=dtype, seed=seed_so2_conv, diff --git a/deepmd/pt/model/descriptor/sezm_nn/cute/__init__.py b/deepmd/pt/model/descriptor/sezm_nn/cute/__init__.py new file mode 100644 index 0000000000..63f146cbd9 --- /dev/null +++ b/deepmd/pt/model/descriptor/sezm_nn/cute/__init__.py @@ -0,0 +1,36 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +""" +CuTe-DSL accelerated SO(2) rotation operators for SeZM / DPA4. + +This package provides a self-contained, ``torch.compile``-friendly implementation +of the two fused gather + batched-GEMM operators used by the SeZM SO(2) edge +convolution: + +* ``rotate_to_local`` : ``out[e] = wigner[e][coeff_index] @ x[src[e]]`` +* ``rotate_back`` : ``out[e] = wigner[e][:, coeff_index] @ x_local[e]`` + +The kernels are written with the NVIDIA CuTe DSL (``cutlass.cute``) and fuse the +Wigner-row/column gather and the source-node gather directly into the matmul, so +the large ``D_to_m`` / ``x_src`` intermediates are never materialized. They are +exposed through the modern ``torch.library.custom_op`` API (functional, with +``register_fake`` + ``register_autograd``) so that they compose correctly with +``torch.compile`` and autograd. + +The top-level entry points are re-exported here for convenience. +""" + +from __future__ import ( + annotations, +) + +from .so2_rotation import ( + SEZM_CUTE_AVAILABLE, + rotate_back_cute, + rotate_to_local_cute, +) + +__all__ = [ + "SEZM_CUTE_AVAILABLE", + "rotate_back_cute", + "rotate_to_local_cute", +] diff --git a/deepmd/pt/model/descriptor/sezm_nn/cute/so2_rotation.py b/deepmd/pt/model/descriptor/sezm_nn/cute/so2_rotation.py new file mode 100644 index 0000000000..f7bf36c743 --- /dev/null +++ b/deepmd/pt/model/descriptor/sezm_nn/cute/so2_rotation.py @@ -0,0 +1,916 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +# pyright: reportMissingImports=false +# ruff: noqa: ANN001 +""" +CuTe-DSL fused SO(2) rotation kernels for SeZM / DPA4. + +Status and benchmark conclusion +------------------------------- +This implementation is experimental and is **not** wired into the production +SO(2) convolution. The shipping accelerated rotation path uses the Triton +block-diagonal kernels in ``sezm_nn/triton/so2_rotation.py`` (enabled by +``DP_TRITON_INFER``); this module is retained for reference and further +experiments. + +In head-to-head benchmarks against the compiled dense ``bmm`` and the Triton +kernels, the CuTe path had the best peak memory (roughly 2-4x lower than the +compiled dense path, lower than Triton) and won the forward pass, but its +``rotate_back`` backward -- and the forward+backward at large ``lmax`` (~10) -- +were slower than cuBLAS. The Triton block-diagonal kernels were chosen for +production because their speed (2-8x over the dense baseline) and native +``torch.compile`` composability outweigh the CuTe memory advantage in the target +``lmax`` 2-5, ``mmax == 1`` regime. + +Operator definitions (ground truth, fp32) +----------------------------------------- +Let ``x`` be packed node features, ``src`` the per-edge source-node indices, +``wigner`` the per-edge block-diagonal Wigner-D matrices, ``coeff_index`` the +``m``-major reduced-layout indices and ``dim_full = D`` the full packed SO(3) +dimension (``D <= Dw`` where ``wigner`` is ``(E, Dw, Dw)``). + +``rotate_to_local`` lifts global node features into the per-edge local frame and +truncates to the reduced layout in one fused step:: + + out[e, i, c] = sum_j wigner[e, coeff_index[i], j] * x[src[e], j, c] + # i in [0, Dm), j in [0, D), c in [0, C) + +``rotate_back`` is the (column-selected) inverse rotation:: + + out[e, i, c] = sum_j wigner[e, i, coeff_index[j]] * x_local[e, j, c] + # i in [0, D), j in [0, Dm), c in [0, C) + +Both operators are batched (one tiny GEMM per edge) with two gathers fused in: +the Wigner row/column selection by ``coeff_index`` and the source-node gather by +``src``. Fusing the gathers means the large ``D_to_m`` ``(E, Dm, D)`` and +``x_src`` ``(E, D, C)`` intermediates produced by the eager ``index_select`` + +``bmm`` reference are never written to or read from global memory, which is the +main source of the speed/peak-memory advantage. + +Backward (both feature *and* ``wigner`` gradients, required for forces) +---------------------------------------------------------------------- +``rotate_to_local``:: + + grad_edge[e, j, c] = sum_i wigner[e, coeff_index[i], j] * grad_out[e, i, c] + grad_x = scatter_add(grad_edge, dim=0, index=src) # (N, D, C) + grad_wigner[e, coeff_index[i], j] = sum_c grad_out[e, i, c] * x[src[e], j, c] + +``rotate_back``:: + + grad_x_local[e, j, c] = sum_i wigner[e, i, coeff_index[j]] * grad_out[e, i, c] + grad_wigner[e, i, coeff_index[j]] = sum_c grad_out[e, i, c] * x_local[e, j, c] + +(all other entries of ``grad_wigner`` are zero). + +Kernel design +------------- +Every kernel computes, per edge, a small matrix product ``out = A @ B`` (with +one operand gathered) using a **2D register-blocked GEMM**: + +* one CUDA block per edge; +* the operand whose layout is ``(K, C)`` (the source-node / local / grad_out + tile) is staged once into shared memory; +* each thread owns a ``TM x TN`` register tile of the output and sweeps the + contraction dimension ``K``, loading ``TM`` values of ``A`` and ``TN`` values + of ``B`` per step and issuing ``TM*TN`` FFMAs. This pushes the load:FFMA ratio + to ``(TM+TN)/(TM*TN)`` so the kernel is compute-bound rather than + load/store-unit bound; +* the per-output-row Wigner index gather (``coeff_index``) is hoisted out of the + contraction loop into registers. + +The two ``grad_wigner`` kernels are batched outer products (contraction over the +channel axis ``C``) and use the same register-blocked skeleton with a 2D tile +sweep over the ``(Dm, D)`` output. When both per-edge operands fit in shared +memory (small/medium ``lmax``) both are staged there; otherwise only +``grad_out`` is staged and the other operand streams from global memory through +L1. The ``rotate_to_local`` ``grad_x`` contribution is fused with its +source-node scatter via atomic adds, so neither a ``grad_edge`` intermediate nor +a separate ``index_add`` is materialized. + +All accumulation is fp32 (no TF32), keeping the potential-energy surface smooth. + +Composability +------------- +The kernels are wrapped with ``torch.library.custom_op`` (functional, +``mutates_args=()``) plus ``register_fake`` and ``register_autograd``. The +backward is itself a custom op, so ``torch.compile`` can include and +differentiate the whole thing as an opaque, side-effect-free operator. Kernels +are launched on torch's current CUDA stream so they order correctly with the +surrounding eager / compiled graph. +""" + +from __future__ import ( + annotations, +) + +import threading +from typing import ( + TYPE_CHECKING, + Any, +) + +import torch +from torch import ( + Tensor, +) + +if TYPE_CHECKING: + from collections.abc import ( + Callable, + ) + + from cuda.bindings import driver as _cuda_driver + +try: + import cutlass + import cutlass.cute as cute + import cutlass.torch as cutlass_torch + from cutlass.cute.runtime import ( + from_dlpack, + ) + + SEZM_CUTE_AVAILABLE = True +except Exception: # pragma: no cover - import guard for non-CuTe environments + SEZM_CUTE_AVAILABLE = False + + +# === Kernel tuning constants ================================================= +# Register-tile dimensions (TM output rows x TN output cols per thread) and the +# block thread geometry. ``C`` (= 64) is the channel axis and is the N dimension +# for the matmul-like kernels; ``TN`` divides ``C``. +_TM = 4 +_TN = 4 +_BLOCK_ROWS = 16 # block.y for matmul-like kernels (block.x = C // TN) + +# grad_wigner: budget (bytes) below which both operands are staged in shared +# memory (the fast path); above it (e.g. lmax=10) only grad_out is staged and +# the other operand streams from global memory through L1. +_GW_SMEM_BUDGET = 46000 + + +def _gw_tile(D: int, Dm: int, C: int) -> tuple[int, int, int, int, bool]: + """Pick (TM, TN, BX, BY, both_in_smem) for a grad_wigner output of (M, N). + + The register tile and block geometry are chosen so the block is well + occupied for the given output size, and both operands are staged in shared + memory when the per-edge tiles fit inside ``_GW_SMEM_BUDGET``. + """ + both = (Dm + D) * C * 4 <= _GW_SMEM_BUDGET + if D <= 20: # small output (e.g. lmax=3): keep the tile/block small + return 2, 2, 16, 16, both + if D <= 50: # medium output (e.g. lmax=5) + return 8, 8, 8, 8, both + if both: # large output that still fits both operands (e.g. lmax=7) + return 8, 4, 8, 16, both + return 8, 8, 8, 8, both # large output, only grad_out staged (e.g. lmax=10) + + +# === Eager reference (ground truth, also used as fallback) =================== +def _rotate_to_local_eager( + x: Tensor, src: Tensor, wigner: Tensor, coeff_index: Tensor, dim_full: int +) -> Tensor: + """Reference ``D_to_m @ x[src]`` used for fallback and validation.""" + d_to_m = wigner[:, :dim_full, :dim_full].index_select(1, coeff_index) + return torch.bmm(d_to_m, x.index_select(0, src)) + + +def _rotate_back_eager( + x_local: Tensor, wigner: Tensor, coeff_index: Tensor, dim_full: int +) -> Tensor: + """Reference ``Dt_from_m @ x_local`` used for fallback and validation.""" + dt_from_m = wigner[:, :dim_full, :dim_full].index_select(2, coeff_index) + return torch.bmm(dt_from_m, x_local) + + +if SEZM_CUTE_AVAILABLE: + _F32 = cutlass.Float32 + _I64 = cutlass.Int64 + + # ------------------------------------------------------------------ + # Family 1: out(M, C) = A(M, K) @ S(K, C), with S staged in shared + # memory and A read from the Wigner tensor with a per-element gather. + # Specialized by how A[m, k] maps into the (Dw, Dw) Wigner block. + # ------------------------------------------------------------------ + def _build_rotate_to_local_fwd(D: int, Dm: int, C: int) -> Callable: + """``out[m=i, n=c] = sum_{k=j} wigner[e, idx[m], k] * x[src[e], k, n]``.""" + M, K = Dm, D + TM, TN, BY = _TM, _TN, _BLOCK_ROWS + BX = C // TN + T = BX * BY + + @cute.kernel + def kernel(m_x, m_src, m_w, m_idx, m_out) -> None: + e, _, _ = cute.arch.block_idx() + cx, ry, _ = cute.arch.thread_idx() + smem = cute.arch.alloc_smem(_F32, K * C) + s_s = cute.make_tensor(smem, cute.make_layout((K, C), stride=(C, 1))) + src_e = m_src[e] + x_node = m_x[src_e, None, None] + tid = ry * BX + cx + for kk in cutlass.range(tid, K * C, T): + s_s[kk // C, kk % C] = x_node[kk // C, kk % C] + cute.arch.sync_threads() + + w_e = m_w[e, None, None] + out_e = m_out[e, None, None] + for rt0 in cutlass.range(ry * TM, M, BY * TM): + acc = cute.make_fragment((TM, TN), _F32) + wi = cute.make_fragment((TM,), _I64) + bf = cute.make_fragment((TN,), _F32) + for t in range(TM): + wi[t] = m_idx[(rt0 + t) % M] # gathered Wigner row + for n in range(TN): + acc[t, n] = _F32(0.0) + for k in cutlass.range(K): + for n in range(TN): + bf[n] = s_s[k, cx * TN + n] + for t in range(TM): + a = w_e[wi[t], k] + for n in range(TN): + acc[t, n] = acc[t, n] + a * bf[n] + for t in range(TM): + m = rt0 + t + if m < M: + for n in range(TN): + out_e[m, cx * TN + n] = acc[t, n] + + @cute.jit + def host(m_x, m_src, m_w, m_idx, m_out, stream: _cuda_driver.CUstream) -> None: + e = m_out.shape[0] + kernel(m_x, m_src, m_w, m_idx, m_out).launch( + grid=[e, 1, 1], block=[BX, BY, 1], stream=stream + ) + + return host + + def _build_rotate_back_fwd(D: int, Dm: int, C: int) -> Callable: + """``out[m=i, n=c] = sum_{k=j} wigner[e, m, idx[k]] * x_local[e, k, n]``.""" + M, K = D, Dm + TM, TN, BY = _TM, _TN, _BLOCK_ROWS + BX = C // TN + T = BX * BY + + @cute.kernel + def kernel(m_xl, m_w, m_idx, m_out) -> None: + e, _, _ = cute.arch.block_idx() + cx, ry, _ = cute.arch.thread_idx() + smem = cute.arch.alloc_smem(_F32, K * C) + s_s = cute.make_tensor(smem, cute.make_layout((K, C), stride=(C, 1))) + xl_e = m_xl[e, None, None] + tid = ry * BX + cx + for kk in cutlass.range(tid, K * C, T): + s_s[kk // C, kk % C] = xl_e[kk // C, kk % C] + cute.arch.sync_threads() + + w_e = m_w[e, None, None] + out_e = m_out[e, None, None] + for rt0 in cutlass.range(ry * TM, M, BY * TM): + acc = cute.make_fragment((TM, TN), _F32) + wr = cute.make_fragment((TM,), _I64) + bf = cute.make_fragment((TN,), _F32) + for t in range(TM): + wr[t] = (rt0 + t) % M # direct Wigner row + for n in range(TN): + acc[t, n] = _F32(0.0) + for k in cutlass.range(K): + kk = m_idx[k] # gathered Wigner column + for n in range(TN): + bf[n] = s_s[k, cx * TN + n] + for t in range(TM): + a = w_e[wr[t], kk] + for n in range(TN): + acc[t, n] = acc[t, n] + a * bf[n] + for t in range(TM): + m = rt0 + t + if m < M: + for n in range(TN): + out_e[m, cx * TN + n] = acc[t, n] + + @cute.jit + def host(m_xl, m_w, m_idx, m_out, stream: _cuda_driver.CUstream) -> None: + e = m_out.shape[0] + kernel(m_xl, m_w, m_idx, m_out).launch( + grid=[e, 1, 1], block=[BX, BY, 1], stream=stream + ) + + return host + + def _build_rotate_to_local_bwd_dx(D: int, Dm: int, C: int) -> Callable: + """``grad_x[src[e], m=j, n=c] += sum_{k=i} wigner[e, idx[k], m] * grad_out[e, k, n]``. + + The per-edge gradient and the scatter-add into ``grad_x`` (indexed by + ``src``) are fused: each block accumulates its tile and atomically adds it + into the destination node. This avoids a materialized ``grad_edge`` tensor + and a separate ``index_add`` pass. + """ + M, K = D, Dm + TM, TN, BY = _TM, _TN, _BLOCK_ROWS + BX = C // TN + T = BX * BY + + @cute.kernel + def kernel(m_go, m_w, m_src, m_idx, m_gx) -> None: + e, _, _ = cute.arch.block_idx() + cx, ry, _ = cute.arch.thread_idx() + smem = cute.arch.alloc_smem(_F32, K * C) + s_s = cute.make_tensor(smem, cute.make_layout((K, C), stride=(C, 1))) + go_e = m_go[e, None, None] + tid = ry * BX + cx + for kk in cutlass.range(tid, K * C, T): + s_s[kk // C, kk % C] = go_e[kk // C, kk % C] + cute.arch.sync_threads() + + w_e = m_w[e, None, None] + gx_node = m_gx[m_src[e], None, None] # (D, C) view into grad_x[src] + gx_base = gx_node.iterator # contiguous (C, 1): element (m, c) -> m*C + c + for rt0 in cutlass.range(ry * TM, M, BY * TM): + acc = cute.make_fragment((TM, TN), _F32) + wc = cute.make_fragment((TM,), _I64) + bf = cute.make_fragment((TN,), _F32) + for t in range(TM): + wc[t] = (rt0 + t) % M # direct Wigner column (= output row m) + for n in range(TN): + acc[t, n] = _F32(0.0) + for k in cutlass.range(K): + kk = m_idx[k] # gathered Wigner row + for n in range(TN): + bf[n] = s_s[k, cx * TN + n] + for t in range(TM): + a = w_e[kk, wc[t]] + for n in range(TN): + acc[t, n] = acc[t, n] + a * bf[n] + for t in range(TM): + m = rt0 + t + if m < M: + for n in range(TN): + cute.arch.atomic_add( + gx_base + (m * C + cx * TN + n), acc[t, n] + ) + + @cute.jit + def host(m_go, m_w, m_src, m_idx, m_gx, stream: _cuda_driver.CUstream) -> None: + e = m_go.shape[0] + kernel(m_go, m_w, m_src, m_idx, m_gx).launch( + grid=[e, 1, 1], block=[BX, BY, 1], stream=stream + ) + + return host + + def _build_rotate_back_bwd_dx(D: int, Dm: int, C: int) -> Callable: + """``grad_x_local[m=j, n=c] = sum_{k=i} wigner[e, k, idx[m]] * grad_out[e, k, n]``.""" + M, K = Dm, D + TM, TN, BY = _TM, _TN, _BLOCK_ROWS + BX = C // TN + T = BX * BY + + @cute.kernel + def kernel(m_go, m_w, m_idx, m_gxl) -> None: + e, _, _ = cute.arch.block_idx() + cx, ry, _ = cute.arch.thread_idx() + smem = cute.arch.alloc_smem(_F32, K * C) + s_s = cute.make_tensor(smem, cute.make_layout((K, C), stride=(C, 1))) + go_e = m_go[e, None, None] + tid = ry * BX + cx + for kk in cutlass.range(tid, K * C, T): + s_s[kk // C, kk % C] = go_e[kk // C, kk % C] + cute.arch.sync_threads() + + w_e = m_w[e, None, None] + gxl_e = m_gxl[e, None, None] + for rt0 in cutlass.range(ry * TM, M, BY * TM): + acc = cute.make_fragment((TM, TN), _F32) + wc = cute.make_fragment((TM,), _I64) + bf = cute.make_fragment((TN,), _F32) + for t in range(TM): + wc[t] = m_idx[(rt0 + t) % M] # gathered Wigner column + for n in range(TN): + acc[t, n] = _F32(0.0) + for k in cutlass.range(K): + for n in range(TN): + bf[n] = s_s[k, cx * TN + n] + for t in range(TM): + a = w_e[k, wc[t]] + for n in range(TN): + acc[t, n] = acc[t, n] + a * bf[n] + for t in range(TM): + m = rt0 + t + if m < M: + for n in range(TN): + gxl_e[m, cx * TN + n] = acc[t, n] + + @cute.jit + def host(m_go, m_w, m_idx, m_gxl, stream: _cuda_driver.CUstream) -> None: + e = m_go.shape[0] + kernel(m_go, m_w, m_idx, m_gxl).launch( + grid=[e, 1, 1], block=[BX, BY, 1], stream=stream + ) + + return host + + # ------------------------------------------------------------------ + # Family 2: grad_wigner = grad_out @ other^T (contraction over the + # channel axis C). 2D register-blocked sweep over the (M, N) output, + # grad_out staged in shared memory, other read from global memory. + # ------------------------------------------------------------------ + def _build_rotate_to_local_bwd_dw(D: int, Dm: int, C: int) -> Callable: + """``grad_wigner[e, idx[m=i], n=j] = sum_{k=c} grad_out[e, m, k] * x[src[e], n, k]``.""" + M, N, K = Dm, D, C + TM, TN, BX, BY, both = _gw_tile(D, Dm, C) + T = BX * BY + + @cute.kernel + def kernel(m_go, m_x, m_src, m_idx, m_gw) -> None: + e, _, _ = cute.arch.block_idx() + cx, ry, _ = cute.arch.thread_idx() + sgo = cute.arch.alloc_smem(_F32, M * C) + s_go = cute.make_tensor(sgo, cute.make_layout((M, C), stride=(C, 1))) + go_e = m_go[e, None, None] + src_e = m_src[e] + x_node = m_x[src_e, None, None] # (N=D, C) + tid = ry * BX + cx + for kk in cutlass.range(tid, M * C, T): + s_go[kk // C, kk % C] = go_e[kk // C, kk % C] + # Optionally stage the second operand in shared memory too. + sx = cute.arch.alloc_smem(_F32, (N * C) if both else 1) + s_x = cute.make_tensor( + sx, cute.make_layout(((N, C) if both else (1, 1)), stride=(C, 1)) + ) + if cutlass.const_expr(both): + for kk in cutlass.range(tid, N * C, T): + s_x[kk // C, kk % C] = x_node[kk // C, kk % C] + cute.arch.sync_threads() + + gw_e = m_gw[e, None, None] + for mt0 in cutlass.range(ry * TM, M, BY * TM): + orow = cute.make_fragment((TM,), _I64) + rt = cute.make_fragment((TM,), cutlass.Int32) + for t in range(TM): + rt[t] = (mt0 + t) % M # clamped smem row (hoisted out of K loop) + orow[t] = m_idx[rt[t]] # gathered output row + for nt0 in cutlass.range(cx * TN, N, BX * TN): + acc = cute.make_fragment((TM, TN), _F32) + af = cute.make_fragment((TM,), _F32) + bf = cute.make_fragment((TN,), _F32) + ct = cute.make_fragment((TN,), cutlass.Int32) + for n in range(TN): + ct[n] = (nt0 + n) % N # clamped col (hoisted) + for t in range(TM): + for n in range(TN): + acc[t, n] = _F32(0.0) + for k in cutlass.range(K): + for t in range(TM): + af[t] = s_go[rt[t], k] + if cutlass.const_expr(both): + for n in range(TN): + bf[n] = s_x[ct[n], k] + else: + for n in range(TN): + bf[n] = x_node[ct[n], k] + for t in range(TM): + for n in range(TN): + acc[t, n] = acc[t, n] + af[t] * bf[n] + for t in range(TM): + if mt0 + t < M: + for n in range(TN): + if nt0 + n < N: + gw_e[orow[t], nt0 + n] = acc[t, n] + + @cute.jit + def host(m_go, m_x, m_src, m_idx, m_gw, stream: _cuda_driver.CUstream) -> None: + e = m_go.shape[0] + kernel(m_go, m_x, m_src, m_idx, m_gw).launch( + grid=[e, 1, 1], block=[BX, BY, 1], stream=stream + ) + + return host + + def _build_rotate_back_bwd_dw(D: int, Dm: int, C: int) -> Callable: + """``grad_wigner[e, m=i, idx[n=j]] = sum_{k=c} grad_out[e, m, k] * x_local[e, n, k]``.""" + M, N, K = D, Dm, C + TM, TN, BX, BY, both = _gw_tile(D, Dm, C) + T = BX * BY + + @cute.kernel + def kernel(m_go, m_xl, m_idx, m_gw) -> None: + e, _, _ = cute.arch.block_idx() + cx, ry, _ = cute.arch.thread_idx() + sgo = cute.arch.alloc_smem(_F32, M * C) + s_go = cute.make_tensor(sgo, cute.make_layout((M, C), stride=(C, 1))) + go_e = m_go[e, None, None] + xl_e = m_xl[e, None, None] # (N=Dm, C) + tid = ry * BX + cx + for kk in cutlass.range(tid, M * C, T): + s_go[kk // C, kk % C] = go_e[kk // C, kk % C] + sx = cute.arch.alloc_smem(_F32, (N * C) if both else 1) + s_x = cute.make_tensor( + sx, cute.make_layout(((N, C) if both else (1, 1)), stride=(C, 1)) + ) + if cutlass.const_expr(both): + for kk in cutlass.range(tid, N * C, T): + s_x[kk // C, kk % C] = xl_e[kk // C, kk % C] + cute.arch.sync_threads() + + gw_e = m_gw[e, None, None] + for mt0 in cutlass.range(ry * TM, M, BY * TM): + rt = cute.make_fragment((TM,), cutlass.Int32) + for t in range(TM): + rt[t] = (mt0 + t) % M # clamped smem row (hoisted out of K loop) + for nt0 in cutlass.range(cx * TN, N, BX * TN): + acc = cute.make_fragment((TM, TN), _F32) + ocol = cute.make_fragment((TN,), _I64) + ct = cute.make_fragment((TN,), cutlass.Int32) + af = cute.make_fragment((TM,), _F32) + bf = cute.make_fragment((TN,), _F32) + for n in range(TN): + ct[n] = (nt0 + n) % N # clamped col (hoisted) + ocol[n] = m_idx[ct[n]] # gathered output column + for t in range(TM): + for n in range(TN): + acc[t, n] = _F32(0.0) + for k in cutlass.range(K): + for t in range(TM): + af[t] = s_go[rt[t], k] + if cutlass.const_expr(both): + for n in range(TN): + bf[n] = s_x[ct[n], k] + else: + for n in range(TN): + bf[n] = xl_e[ct[n], k] + for t in range(TM): + for n in range(TN): + acc[t, n] = acc[t, n] + af[t] * bf[n] + for t in range(TM): + i = mt0 + t + if i < M: + for n in range(TN): + if nt0 + n < N: + gw_e[i, ocol[n]] = acc[t, n] + + @cute.jit + def host(m_go, m_xl, m_idx, m_gw, stream: _cuda_driver.CUstream) -> None: + e = m_go.shape[0] + kernel(m_go, m_xl, m_idx, m_gw).launch( + grid=[e, 1, 1], block=[BX, BY, 1], stream=stream + ) + + return host + + # === Compiled-kernel cache ============================================== + _compiled_cache: dict[tuple, Any] = {} + _cache_lock = threading.Lock() + + def _get_compiled(key: tuple, builder: Callable, example_args: tuple) -> Any: + """Return a JIT-compiled host function, compiling and caching on miss.""" + comp = _compiled_cache.get(key) + if comp is not None: + return comp + with _cache_lock: + comp = _compiled_cache.get(key) + if comp is None: + host = builder(*key[1:]) + comp = cute.compile(host, *example_args) + _compiled_cache[key] = comp + return comp + + def _cute_f(t: Tensor) -> Any: + """Wrap a contiguous (>=2D) fp32 tensor as a CuTe tensor (last dim leading).""" + return from_dlpack(t).mark_layout_dynamic(leading_dim=t.dim() - 1) + + def _cute_i(t: Tensor) -> Any: + """Wrap a contiguous 1D int64 tensor as a CuTe tensor.""" + return from_dlpack(t).mark_layout_dynamic() + + # === Low-level kernel dispatch (operate on plain, detached tensors) ====== + def _launch_rotate_to_local_fwd( + x: Tensor, src: Tensor, wigner: Tensor, coeff_index: Tensor, dim_full: int + ) -> Tensor: + e = src.shape[0] + d, dm, c = dim_full, coeff_index.shape[0], x.shape[2] + out = torch.empty((e, dm, c), dtype=x.dtype, device=x.device) + m_x, m_src, m_w = _cute_f(x), _cute_i(src), _cute_f(wigner) + m_idx, m_out = _cute_i(coeff_index), _cute_f(out) + stream = cutlass_torch.current_stream() + comp = _get_compiled( + ("rtl_fwd", d, dm, c), + _build_rotate_to_local_fwd, + (m_x, m_src, m_w, m_idx, m_out, stream), + ) + comp(m_x, m_src, m_w, m_idx, m_out, stream) + return out + + def _launch_rotate_back_fwd( + x_local: Tensor, wigner: Tensor, coeff_index: Tensor, dim_full: int + ) -> Tensor: + e = x_local.shape[0] + d, dm, c = dim_full, coeff_index.shape[0], x_local.shape[2] + out = torch.empty((e, d, c), dtype=x_local.dtype, device=x_local.device) + m_xl, m_w = _cute_f(x_local), _cute_f(wigner) + m_idx, m_out = _cute_i(coeff_index), _cute_f(out) + stream = cutlass_torch.current_stream() + comp = _get_compiled( + ("rb_fwd", d, dm, c), + _build_rotate_back_fwd, + (m_xl, m_w, m_idx, m_out, stream), + ) + comp(m_xl, m_w, m_idx, m_out, stream) + return out + + def _launch_rotate_to_local_bwd( + grad_out: Tensor, + x: Tensor, + src: Tensor, + wigner: Tensor, + coeff_index: Tensor, + dim_full: int, + ) -> tuple[Tensor, Tensor]: + n, e = x.shape[0], src.shape[0] + d, dm, c = dim_full, coeff_index.shape[0], x.shape[2] + stream = cutlass_torch.current_stream() + + # grad_x: per-edge gradient fused with the scatter-add into the source + # node via atomic adds (no materialized grad_edge, no separate index_add). + grad_x = torch.zeros((n, d, c), dtype=x.dtype, device=x.device) + m_go, m_w = _cute_f(grad_out), _cute_f(wigner) + m_src, m_idx, m_gx = _cute_i(src), _cute_i(coeff_index), _cute_f(grad_x) + comp_dx = _get_compiled( + ("rtl_bwd_dx", d, dm, c), + _build_rotate_to_local_bwd_dx, + (m_go, m_w, m_src, m_idx, m_gx, stream), + ) + comp_dx(m_go, m_w, m_src, m_idx, m_gx, stream) + + # grad_wigner: per-edge outer product written into the gathered rows. + grad_wigner = torch.zeros_like(wigner) + m_x, m_gw = _cute_f(x), _cute_f(grad_wigner) + comp_dw = _get_compiled( + ("rtl_bwd_dw", d, dm, c), + _build_rotate_to_local_bwd_dw, + (m_go, m_x, m_src, m_idx, m_gw, stream), + ) + comp_dw(m_go, m_x, m_src, m_idx, m_gw, stream) + return grad_x, grad_wigner + + def _launch_rotate_back_bwd( + grad_out: Tensor, + x_local: Tensor, + wigner: Tensor, + coeff_index: Tensor, + dim_full: int, + ) -> tuple[Tensor, Tensor]: + e = x_local.shape[0] + d, dm, c = dim_full, coeff_index.shape[0], x_local.shape[2] + stream = cutlass_torch.current_stream() + + grad_x_local = torch.empty( + (e, dm, c), dtype=x_local.dtype, device=x_local.device + ) + m_go, m_w = _cute_f(grad_out), _cute_f(wigner) + m_idx, m_gxl = _cute_i(coeff_index), _cute_f(grad_x_local) + comp_dx = _get_compiled( + ("rb_bwd_dx", d, dm, c), + _build_rotate_back_bwd_dx, + (m_go, m_w, m_idx, m_gxl, stream), + ) + comp_dx(m_go, m_w, m_idx, m_gxl, stream) + + grad_wigner = torch.zeros_like(wigner) + m_xl, m_gw = _cute_f(x_local), _cute_f(grad_wigner) + comp_dw = _get_compiled( + ("rb_bwd_dw", d, dm, c), + _build_rotate_back_bwd_dw, + (m_go, m_xl, m_idx, m_gw, stream), + ) + comp_dw(m_go, m_xl, m_idx, m_gw, stream) + return grad_x_local, grad_wigner + + # === torch.library custom ops =========================================== + # Forward + backward are registered as functional custom ops so the whole + # operator is opaque to torch.compile yet correctly differentiable. + + @torch.library.custom_op( + "sezm_cute::rotate_to_local", mutates_args=(), device_types="cuda" + ) + def _op_rotate_to_local( + x: Tensor, + src: Tensor, + wigner: Tensor, + coeff_index: Tensor, + dim_full: int, + ) -> Tensor: + return _launch_rotate_to_local_fwd( + x.detach().contiguous(), + src.detach().contiguous(), + wigner.detach().contiguous(), + coeff_index.detach().contiguous(), + int(dim_full), + ) + + @_op_rotate_to_local.register_fake + def _( + x: Tensor, + src: Tensor, + wigner: Tensor, + coeff_index: Tensor, + dim_full: int, + ) -> Tensor: + return x.new_empty((src.shape[0], coeff_index.shape[0], x.shape[2])) + + @torch.library.custom_op( + "sezm_cute::rotate_to_local_bwd", mutates_args=(), device_types="cuda" + ) + def _op_rotate_to_local_bwd( + grad_out: Tensor, + x: Tensor, + src: Tensor, + wigner: Tensor, + coeff_index: Tensor, + dim_full: int, + ) -> tuple[Tensor, Tensor]: + return _launch_rotate_to_local_bwd( + grad_out.detach().contiguous(), + x.detach().contiguous(), + src.detach().contiguous(), + wigner.detach().contiguous(), + coeff_index.detach().contiguous(), + int(dim_full), + ) + + @_op_rotate_to_local_bwd.register_fake + def _( + grad_out: Tensor, + x: Tensor, + src: Tensor, + wigner: Tensor, + coeff_index: Tensor, + dim_full: int, + ) -> tuple[Tensor, Tensor]: + return torch.empty_like(x), torch.empty_like(wigner) + + def _rtl_setup_context(ctx: Any, inputs: tuple, output: Tensor) -> None: + x, src, wigner, coeff_index, dim_full = inputs + ctx.save_for_backward(x, src, wigner, coeff_index) + ctx.dim_full = int(dim_full) + + def _rtl_backward(ctx: Any, grad_out: Tensor) -> tuple: + x, src, wigner, coeff_index = ctx.saved_tensors + grad_x, grad_wigner = torch.ops.sezm_cute.rotate_to_local_bwd( + grad_out, x, src, wigner, coeff_index, ctx.dim_full + ) + return grad_x, None, grad_wigner, None, None + + _op_rotate_to_local.register_autograd( + _rtl_backward, setup_context=_rtl_setup_context + ) + + @torch.library.custom_op( + "sezm_cute::rotate_back", mutates_args=(), device_types="cuda" + ) + def _op_rotate_back( + x_local: Tensor, + wigner: Tensor, + coeff_index: Tensor, + dim_full: int, + ) -> Tensor: + return _launch_rotate_back_fwd( + x_local.detach().contiguous(), + wigner.detach().contiguous(), + coeff_index.detach().contiguous(), + int(dim_full), + ) + + @_op_rotate_back.register_fake + def _( + x_local: Tensor, + wigner: Tensor, + coeff_index: Tensor, + dim_full: int, + ) -> Tensor: + return x_local.new_empty((x_local.shape[0], dim_full, x_local.shape[2])) + + @torch.library.custom_op( + "sezm_cute::rotate_back_bwd", mutates_args=(), device_types="cuda" + ) + def _op_rotate_back_bwd( + grad_out: Tensor, + x_local: Tensor, + wigner: Tensor, + coeff_index: Tensor, + dim_full: int, + ) -> tuple[Tensor, Tensor]: + return _launch_rotate_back_bwd( + grad_out.detach().contiguous(), + x_local.detach().contiguous(), + wigner.detach().contiguous(), + coeff_index.detach().contiguous(), + int(dim_full), + ) + + @_op_rotate_back_bwd.register_fake + def _( + grad_out: Tensor, + x_local: Tensor, + wigner: Tensor, + coeff_index: Tensor, + dim_full: int, + ) -> tuple[Tensor, Tensor]: + return torch.empty_like(x_local), torch.empty_like(wigner) + + def _rb_setup_context(ctx: Any, inputs: tuple, output: Tensor) -> None: + x_local, wigner, coeff_index, dim_full = inputs + ctx.save_for_backward(x_local, wigner, coeff_index) + ctx.dim_full = int(dim_full) + + def _rb_backward(ctx: Any, grad_out: Tensor) -> tuple: + x_local, wigner, coeff_index = ctx.saved_tensors + grad_x_local, grad_wigner = torch.ops.sezm_cute.rotate_back_bwd( + grad_out, x_local, wigner, coeff_index, ctx.dim_full + ) + return grad_x_local, grad_wigner, None, None + + _op_rotate_back.register_autograd(_rb_backward, setup_context=_rb_setup_context) + + +# === Public API ============================================================== +def _cute_usable(*tensors: Tensor) -> bool: + """Return True when the CuTe fast path is available for these tensors.""" + if not SEZM_CUTE_AVAILABLE: + return False + return all( + t.is_cuda and t.dtype == torch.float32 for t in tensors if t.is_floating_point() + ) + + +def rotate_to_local_cute( + x: Tensor, + src: Tensor, + wigner: Tensor, + coeff_index: Tensor, + dim_full: int, +) -> Tensor: + """ + Fused ``global -> local reduced`` rotation (CuTe fast path with eager fallback). + + Parameters + ---------- + x + Node features with shape ``(N, D, C)``. + src + Source-node indices with shape ``(E,)``. + wigner + Packed Wigner-D matrices with shape ``(E, Dw, Dw)`` (``Dw >= dim_full``). + coeff_index + Reduced-layout row indices with shape ``(Dm,)``. + dim_full + Full packed SO(3) dimension ``D``. + + Returns + ------- + Tensor + Rotated reduced-layout edge features with shape ``(E, Dm, C)``. + + Notes + ----- + Experimental path that is not used in production. See the module docstring + for the benchmark conclusion and why the Triton kernels were chosen instead. + """ + if _cute_usable(x, wigner) and src.numel() > 0: + return torch.ops.sezm_cute.rotate_to_local( + x, src, wigner, coeff_index, int(dim_full) + ) + return _rotate_to_local_eager(x, src, wigner, coeff_index, dim_full) + + +def rotate_back_cute( + x_local: Tensor, + wigner: Tensor, + coeff_index: Tensor, + dim_full: int, +) -> Tensor: + """ + Fused ``local reduced -> global`` rotation (CuTe fast path with eager fallback). + + Parameters + ---------- + x_local + Reduced-layout edge features with shape ``(E, Dm, C)``. + wigner + Packed Wigner-D matrices with shape ``(E, Dw, Dw)`` (``Dw >= dim_full``). + coeff_index + Reduced-layout column indices with shape ``(Dm,)``. + dim_full + Full packed SO(3) dimension ``D``. + + Returns + ------- + Tensor + Lifted global-layout edge features with shape ``(E, D, C)``. + + Notes + ----- + Experimental path that is not used in production. See the module docstring + for the benchmark conclusion and why the Triton kernels were chosen instead. + """ + if _cute_usable(x_local, wigner) and x_local.shape[0] > 0: + return torch.ops.sezm_cute.rotate_back( + x_local, wigner, coeff_index, int(dim_full) + ) + return _rotate_back_eager(x_local, wigner, coeff_index, dim_full) diff --git a/deepmd/pt/model/descriptor/sezm_nn/edge_cache.py b/deepmd/pt/model/descriptor/sezm_nn/edge_cache.py index 707a6ef8db..2174dfd4b3 100644 --- a/deepmd/pt/model/descriptor/sezm_nn/edge_cache.py +++ b/deepmd/pt/model/descriptor/sezm_nn/edge_cache.py @@ -24,9 +24,6 @@ rearrange, ) -from .triton import ( - edge_geometry_rbf_triton, -) from .utils import ( get_promoted_dtype, nvtx_range, @@ -222,7 +219,6 @@ def build_edge_cache( n_radial: int, random_gamma: bool, wigner_calc: WignerCalculatorFn, - use_geometry_rbf_triton: bool = False, ) -> EdgeFeatureCache: """ Build the global edge cache from DeePMD padded neighbor list. @@ -284,9 +280,6 @@ def build_edge_cache( wigner_calc Callable that converts edge-aligned quaternions into packed Wigner-D blocks. - use_geometry_rbf_triton - Whether to allow the standard-path fused Triton geometry/RBF chain - ``gather -> vec -> len -> env -> rbf``. Returns ------- @@ -320,47 +313,27 @@ def build_edge_cache( ) # === Step 3-5. Edge geometry/RBF chain === - # This segment covers: # gather -> edge_vec -> edge_len -> edge_env -> edge_rbf - # The Triton path is only used on the standard non-compile path when the - # caller explicitly allows it (descriptor eval/inference path). Bridging - # primitives never enter here; they are owned by the sparse-edge path. coord_flat = coord.reshape(nf * nall, 3) - use_bessel_triton = ( - use_geometry_rbf_triton - and getattr(radial_basis, "basis_type", "bessel") == "bessel" - ) - if use_bessel_triton: - with nvtx_range("edge_geometry_rbf_triton"): - edge_vec, edge_len, edge_env, edge_rbf = edge_geometry_rbf_triton( - coord_flat=coord_flat, - center_coord_index=center_coord_index, - neighbor_coord_index=neighbor_coord_index, - edge_envelope=edge_envelope, - radial_basis=radial_basis, - eps=eps, - inner_clamp=None, - ) - else: - # === Step 3. Gather per-edge geometry === - # edge_vec points from center -> neighbor: r_ij = r_j - r_i (in Å). - # edge_len is the scalar distance. - with nvtx_range("edge_geom"): - center_pos = coord_flat.index_select(0, center_coord_index) - neighbor_pos = coord_flat.index_select(0, neighbor_coord_index) - edge_vec = neighbor_pos - center_pos # (E, 3) - edge_len = safe_norm(edge_vec, eps) # (E, 1) - - # === Step 4. C^3 envelope weight === - # Edges with r >= rcut are not removed from the cache. Their envelope is - # exactly zero, so messages vanish naturally while degree normalization - # remains smooth at the cutoff boundary. - with nvtx_range("envelope"): - edge_env = edge_envelope(edge_len) # (E, 1) - - # === Step 5. Radial basis (envelope already baked in) === - with nvtx_range("radial_basis"): - edge_rbf = radial_basis(edge_len) # (E, n_radial) + # === Step 3. Gather per-edge geometry === + # edge_vec points from center -> neighbor: r_ij = r_j - r_i (in Å). + # edge_len is the scalar distance. + with nvtx_range("edge_geom"): + center_pos = coord_flat.index_select(0, center_coord_index) + neighbor_pos = coord_flat.index_select(0, neighbor_coord_index) + edge_vec = neighbor_pos - center_pos # (E, 3) + edge_len = safe_norm(edge_vec, eps) # (E, 1) + + # === Step 4. C^3 envelope weight === + # Edges with r >= rcut are not removed from the cache. Their envelope is + # exactly zero, so messages vanish naturally while degree normalization + # remains smooth at the cutoff boundary. + with nvtx_range("envelope"): + edge_env = edge_envelope(edge_len) # (E, 1) + + # === Step 5. Radial basis (envelope already baked in) === + with nvtx_range("radial_basis"): + edge_rbf = radial_basis(edge_len) # (E, n_radial) # === Step 6. Edge quaternion -> Wigner-D blocks === with nvtx_range("wigner_d"): diff --git a/deepmd/pt/model/descriptor/sezm_nn/so2.py b/deepmd/pt/model/descriptor/sezm_nn/so2.py index efa175dcef..a36cb25809 100644 --- a/deepmd/pt/model/descriptor/sezm_nn/so2.py +++ b/deepmd/pt/model/descriptor/sezm_nn/so2.py @@ -11,6 +11,7 @@ ) import math +import os from typing import ( TYPE_CHECKING, Any, @@ -71,10 +72,8 @@ SO3Linear, ) from .triton import ( - resolve_triton_rotation_mode, - rotate_back_triton, - rotate_to_local_triton, - sezm_triton_enabled, + rotate_back, + rotate_to_local, ) from .utils import ( ATTN_RES_MODES, @@ -794,9 +793,6 @@ class SO2Convolution(nn.Module): mlp_bias Whether to use bias in SO2Linear (l=0 bias) and GatedActivation (gate linear bias). - use_triton - If True, opt into fused Triton SO(2) rotation kernels on supported - CUDA dtypes. The eager projection path remains the default. radial_so2_mode Dynamic radial degree mixer mode. ``"none"`` applies elementwise radial modulation, ``"degree"`` applies a channel-shared dynamic @@ -845,7 +841,6 @@ def __init__( lebedev_quadrature: bool = False, activation_function: str = "silu", mlp_bias: bool = False, - use_triton: bool = False, radial_so2_mode: str = "none", radial_so2_rank: int = 0, eps: float = 1e-7, @@ -941,7 +936,6 @@ def __init__( else int(self.attn_focus_dim // self.n_atten_head) ) self.mlp_bias = bool(mlp_bias) - self.use_triton = bool(use_triton) self.radial_so2_mode = str(radial_so2_mode).lower() if self.radial_so2_mode not in {"none", "degree", "degree_channel"}: raise ValueError( @@ -956,10 +950,14 @@ def __init__( self.device = env.DEVICE self.precision = RESERVED_PRECISION_DICT[dtype] self.compute_dtype = get_promoted_dtype(self.dtype) - self.use_triton_rotations = self.use_triton and sezm_triton_enabled( - device=self.device, - dtype=self.dtype, - ) + # Optional Triton rotation kernels for the SO(2) convolution, enabled by + # ``DP_TRITON_INFER=1`` (default disabled, in which case the dense + # ``bmm`` rotation is used). The flag is read once at construction so it + # is a compile-time constant in the traced (``make_fx``) graph, and it + # only takes effect during inference. + self.use_triton_infer = os.environ.get( + "DP_TRITON_INFER", "0" + ).strip().lower() in ("1", "true", "yes", "on") # === Step 1. Precompute coefficient indices for m-major reduced layout === coeff_index_m = build_m_major_index(self.lmax, self.mmax, device=self.device) @@ -978,10 +976,6 @@ def __init__( "rotate_inv_rescale_full", rotate_inv_rescale_full, persistent=True ) self.reduced_dim = int(coeff_index_m.numel()) - self.triton_rotation_mode = resolve_triton_rotation_mode( - dim_full=self.ebed_dim_full, - reduced_dim=self.reduced_dim, - ) # === Step 2. Split deterministic seeds at the module top-level === seed_so2_stack = child_seed(seed, 0) @@ -1448,23 +1442,27 @@ def forward( with nvtx_range("SO2Conv/rotate_to_local"): D_full = edge_cache.D_full x_dst_local: torch.Tensor | None = None - if self.use_triton_rotations and not self.training: - x_local = rotate_to_local_triton( - x=x, - src=src, - wigner=D_full, - coeff_index=self.coeff_index_m, - dim_full=self.ebed_dim_full, - rotation_mode=self.triton_rotation_mode, + if self.use_triton_infer and not self.training: + # ``rotate_to_local`` / ``rotate_back`` pick the kernel from the + # coefficient layout, not from this flag: the block-diagonal + # kernel for the canonical m-major ``mmax == 1`` layout used here + # (with ``lmax`` inferred from ``ebed_dim_full``), and a dense + # kernel for any other ``(lmax, mmax)``. Both compose with the + # traced force path through their functional custom-op autograd. + x_local = rotate_to_local( + x, + src, + D_full, + self.coeff_index_m, + self.ebed_dim_full, ) # (E, D_m, C_wide) if self.node_wise_grid_product is not None: - x_dst_local = rotate_to_local_triton( - x=x, - src=dst, - wigner=D_full, - coeff_index=self.coeff_index_m, - dim_full=self.ebed_dim_full, - rotation_mode=self.triton_rotation_mode, + x_dst_local = rotate_to_local( + x, + dst, + D_full, + self.coeff_index_m, + self.ebed_dim_full, ) # (E, D_m, C_wide) else: D_m_prime = project_D_to_m( @@ -1621,13 +1619,12 @@ def apply_bias_correction( # === Step 7. Rotate back to global frame === with nvtx_range("SO2Conv/rotate_back"): Dt_full = edge_cache.Dt_full - if self.use_triton_rotations and not self.training: - x_message = rotate_back_triton( - x_local=x_local, - wigner=Dt_full, - coeff_index=self.coeff_index_m, - dim_full=self.ebed_dim_full, - rotation_mode=self.triton_rotation_mode, + if self.use_triton_infer and not self.training: + x_message = rotate_back( + x_local, + Dt_full, + self.coeff_index_m, + self.ebed_dim_full, ) # (E, D, C_wide) else: Dt_from_m = project_Dt_from_m( diff --git a/deepmd/pt/model/descriptor/sezm_nn/triton/__init__.py b/deepmd/pt/model/descriptor/sezm_nn/triton/__init__.py index 5c80f7824d..eec36931e7 100644 --- a/deepmd/pt/model/descriptor/sezm_nn/triton/__init__.py +++ b/deepmd/pt/model/descriptor/sezm_nn/triton/__init__.py @@ -1,28 +1,17 @@ # SPDX-License-Identifier: LGPL-3.0-or-later -"""Public Triton entry points for SeZM SO(2) rotations.""" +"""Hardware-accelerated SeZM/DPA4 operators. -from .autograd import ( - edge_geometry_rbf_triton, - rotate_back_triton, - rotate_to_local_triton, -) -from .constants import ( - SEZM_TRITON_AVAILABLE, - TritonRotationMode, -) -from .dispatch import ( - resolve_triton_rotation_mode, - sezm_triton_enabled, - uses_triton_kernel, +This package hosts clean, ``torch.compile``-composable Triton implementations of +SeZM hot paths. The first member is the fused SO(2)/Wigner rotation pair used by +the SO(2) convolution (``rotate_to_local`` / ``rotate_back``). +""" + +from .so2_rotation import ( + rotate_back, + rotate_to_local, ) __all__ = [ - "SEZM_TRITON_AVAILABLE", - "TritonRotationMode", - "edge_geometry_rbf_triton", - "resolve_triton_rotation_mode", - "rotate_back_triton", - "rotate_to_local_triton", - "sezm_triton_enabled", - "uses_triton_kernel", + "rotate_back", + "rotate_to_local", ] diff --git a/deepmd/pt/model/descriptor/sezm_nn/triton/autograd.py b/deepmd/pt/model/descriptor/sezm_nn/triton/autograd.py deleted file mode 100644 index dd1c9bbc06..0000000000 --- a/deepmd/pt/model/descriptor/sezm_nn/triton/autograd.py +++ /dev/null @@ -1,837 +0,0 @@ -# SPDX-License-Identifier: LGPL-3.0-or-later -"""Autograd and public API for SeZM Triton kernels.""" - -from __future__ import ( - annotations, -) - -from typing import ( - Any, -) - -import torch -from torch import ( - Tensor, -) - -from ..utils import ( - safe_norm, -) -from .constants import ( - SEZM_TRITON_AVAILABLE, - TritonRotationMode, -) -from .dispatch import ( - coerce_rotation_mode, - resolve_triton_rotation_mode, -) - -if SEZM_TRITON_AVAILABLE: - from . import custom_ops as _custom_ops # noqa: F401 - - -def _compute_cutoff_envelope_eager( - *, - r: Tensor, - rcut: float, - a: float, - b: float, - c: float, - d: float, - exponent: int, -) -> Tensor: - """Reference eager evaluation of the C^3 cutoff envelope.""" - x = (r / rcut).clamp(min=0.0, max=1.0) - poly = a + x * (b + x * (c + x * d)) - env = 1.0 + (x ** int(exponent)) * poly - return env * (x < 1.0).to(dtype=r.dtype) - - -def _edge_geometry_rbf_eager( - *, - coord_flat: Tensor, - center_coord_index: Tensor, - neighbor_coord_index: Tensor, - freqs: Tensor, - eps: float, - rcut: float, - edge_env_a: float, - edge_env_b: float, - edge_env_c: float, - edge_env_d: float, - edge_env_exponent: int, - radial_env_a: float, - radial_env_b: float, - radial_env_c: float, - radial_env_d: float, - radial_env_exponent: int, - r_inner: float, - r_outer: float, - has_inner_clamp: bool, -) -> tuple[Tensor, Tensor, Tensor, Tensor]: - """Reference eager implementation of the edge geometry/RBF chain.""" - center_pos = coord_flat.index_select(0, center_coord_index) - neighbor_pos = coord_flat.index_select(0, neighbor_coord_index) - edge_vec = neighbor_pos - center_pos - raw_len = safe_norm(edge_vec, float(eps)) - edge_len = raw_len - if has_inner_clamp: - delta = float(r_outer - r_inner) - t = ((edge_len - float(r_inner)) / delta).clamp(0.0, 1.0) - t2 = t * t - t4 = t2 * t2 - h = t4 * (20.0 + t * (-45.0 + t * (36.0 - 10.0 * t))) - clamped = float(r_inner) + delta * h - edge_len = torch.where(edge_len >= float(r_outer), edge_len, clamped) - scale = edge_len / raw_len - edge_vec = edge_vec * scale - edge_env = _compute_cutoff_envelope_eager( - r=edge_len, - rcut=float(rcut), - a=float(edge_env_a), - b=float(edge_env_b), - c=float(edge_env_c), - d=float(edge_env_d), - exponent=int(edge_env_exponent), - ) - radial_env = _compute_cutoff_envelope_eager( - r=edge_len, - rcut=float(rcut), - a=float(radial_env_a), - b=float(radial_env_b), - c=float(radial_env_c), - d=float(radial_env_d), - exponent=int(radial_env_exponent), - ) - freqs_row = freqs.view(1, -1) - phase = edge_len * freqs_row - edge_rbf = freqs_row * torch.sinc(phase / torch.pi) * radial_env - return edge_vec, edge_len, edge_env, edge_rbf - - -def _extract_envelope_params( - envelope: Any, -) -> tuple[float, float, float, float, float, int]: - """Extract the polynomial envelope parameters from one SeZM module.""" - return ( - float(envelope.rcut), - float(envelope.coeff_a), - float(envelope.coeff_b), - float(envelope.coeff_c), - float(envelope.coeff_d), - int(envelope.p), - ) - - -def _extract_edge_geometry_rbf_constants( - *, - edge_envelope: Any, - radial_basis: Any, - inner_clamp: Any, -) -> tuple[ - float, - float, - float, - float, - float, - int, - float, - float, - float, - float, - int, - float, - float, - bool, -]: - """Extract scalar constants used by the fused geometry/RBF chain.""" - ( - rcut, - edge_env_a, - edge_env_b, - edge_env_c, - edge_env_d, - edge_env_exponent, - ) = _extract_envelope_params(edge_envelope) - ( - _, - radial_env_a, - radial_env_b, - radial_env_c, - radial_env_d, - radial_env_exponent, - ) = _extract_envelope_params(radial_basis.envelope) - if inner_clamp is None: - r_inner = 0.0 - r_outer = 0.0 - has_inner_clamp = False - else: - r_inner = float(inner_clamp.r_inner) - r_outer = float(inner_clamp.r_outer) - has_inner_clamp = True - return ( - rcut, - edge_env_a, - edge_env_b, - edge_env_c, - edge_env_d, - edge_env_exponent, - radial_env_a, - radial_env_b, - radial_env_c, - radial_env_d, - radial_env_exponent, - r_inner, - r_outer, - has_inner_clamp, - ) - - -def _rotate_to_local_eager( - *, - x: Tensor, - src: Tensor, - wigner: Tensor, - coeff_index: Tensor, - dim_full: int, -) -> Tensor: - """Reference eager implementation for ``D_to_m @ x[src]``.""" - D_to_m = wigner[:, :dim_full, :dim_full].index_select(1, coeff_index) - return torch.bmm(D_to_m, x.index_select(0, src)) - - -def _rotate_back_eager( - *, - x_local: Tensor, - wigner: Tensor, - coeff_index: Tensor, - dim_full: int, -) -> Tensor: - """Reference eager implementation for ``Dt_from_m @ x_local``.""" - Dt_from_m = wigner[:, :dim_full, :dim_full].index_select(2, coeff_index) - return torch.bmm(Dt_from_m, x_local) - - -def _resolve_rotation_mode_for_call( - *, - dim_full: int, - coeff_index: Tensor, - rotation_mode: int | TritonRotationMode | None, -) -> TritonRotationMode: - """Resolve the effective dispatch mode for one public API call.""" - if rotation_mode is None: - return resolve_triton_rotation_mode( - dim_full=int(dim_full), - reduced_dim=int(coeff_index.numel()), - ) - return coerce_rotation_mode(rotation_mode) - - -if SEZM_TRITON_AVAILABLE: - - class _RotateToLocalFunction(torch.autograd.Function): - """Autograd wrapper for the fused ``global -> local reduced`` rotation.""" - - @staticmethod - def forward( - ctx: Any, - x: Tensor, - src: Tensor, - wigner: Tensor, - coeff_index: Tensor, - dim_full: int, - rotation_mode: int, - ) -> Tensor: - reduced_dim = int(coeff_index.numel()) - out = torch.empty( - src.shape[0], - reduced_dim, - x.shape[2], - dtype=x.dtype, - device=x.device, - ) - torch.ops.deepmd._kernel_sezm_rotate_to_local( - x, - src, - wigner, - coeff_index, - out, - dim_full, - rotation_mode, - ) - ctx.save_for_backward(x, src, wigner, coeff_index) - ctx.dim_full = int(dim_full) - ctx.rotation_mode = int(rotation_mode) - return out - - @staticmethod - def backward( - ctx: Any, - grad_out: Tensor, - ) -> tuple[Tensor, None, Tensor, None, None, None]: - x, src, wigner, coeff_index = ctx.saved_tensors - dim_full = int(ctx.dim_full) - rotation_mode = coerce_rotation_mode(int(ctx.rotation_mode)) - grad_out = grad_out.contiguous() - grad_edge = torch.empty( - src.shape[0], - dim_full, - x.shape[2], - dtype=grad_out.dtype, - device=grad_out.device, - ) - torch.ops.deepmd._kernel_sezm_rotate_to_local_bwd_dx( - grad_out, - wigner, - coeff_index, - grad_edge, - dim_full, - int(rotation_mode), - ) - grad_x = torch.zeros_like(x) - grad_x.index_add_(0, src, grad_edge) - - if rotation_mode == TritonRotationMode.GENERIC_TILED: - grad_rows = torch.empty( - src.shape[0], - coeff_index.numel(), - dim_full, - dtype=wigner.dtype, - device=grad_out.device, - ) - torch.ops.deepmd._kernel_sezm_rotate_to_local_bwd_dw( - grad_out, - x, - src, - coeff_index, - grad_rows, - dim_full, - int(rotation_mode), - ) - grad_wigner = torch.zeros_like(wigner) - grad_wigner[:, coeff_index, :dim_full] = grad_rows - else: - grad_wigner = torch.zeros_like(wigner) - torch.ops.deepmd._kernel_sezm_rotate_to_local_bwd_dw( - grad_out, - x, - src, - coeff_index, - grad_wigner, - dim_full, - int(rotation_mode), - ) - return grad_x, None, grad_wigner, None, None, None - - class _RotateBackFunction(torch.autograd.Function): - """Autograd wrapper for the fused ``local reduced -> global`` rotation.""" - - @staticmethod - def forward( - ctx: Any, - x_local: Tensor, - wigner: Tensor, - coeff_index: Tensor, - dim_full: int, - rotation_mode: int, - ) -> Tensor: - out = torch.empty( - x_local.shape[0], - dim_full, - x_local.shape[2], - dtype=x_local.dtype, - device=x_local.device, - ) - torch.ops.deepmd._kernel_sezm_rotate_back( - x_local, - wigner, - coeff_index, - out, - dim_full, - rotation_mode, - ) - ctx.save_for_backward(x_local, wigner, coeff_index) - ctx.dim_full = int(dim_full) - ctx.rotation_mode = int(rotation_mode) - return out - - @staticmethod - def backward( - ctx: Any, - grad_out: Tensor, - ) -> tuple[Tensor, Tensor, None, None, None]: - x_local, wigner, coeff_index = ctx.saved_tensors - dim_full = int(ctx.dim_full) - rotation_mode = coerce_rotation_mode(int(ctx.rotation_mode)) - grad_out = grad_out.contiguous() - grad_x_local = torch.empty_like(x_local) - torch.ops.deepmd._kernel_sezm_rotate_back_bwd_dx( - grad_out, - wigner, - coeff_index, - grad_x_local, - dim_full, - int(rotation_mode), - ) - - if rotation_mode == TritonRotationMode.GENERIC_TILED: - grad_cols = torch.empty( - x_local.shape[0], - dim_full, - coeff_index.numel(), - dtype=wigner.dtype, - device=grad_out.device, - ) - torch.ops.deepmd._kernel_sezm_rotate_back_bwd_dw( - grad_out, - x_local, - coeff_index, - grad_cols, - dim_full, - int(rotation_mode), - ) - grad_wigner = torch.zeros_like(wigner) - grad_wigner[:, :dim_full, coeff_index] = grad_cols - else: - grad_wigner = torch.zeros_like(wigner) - torch.ops.deepmd._kernel_sezm_rotate_back_bwd_dw( - grad_out, - x_local, - coeff_index, - grad_wigner, - dim_full, - int(rotation_mode), - ) - return grad_x_local, grad_wigner, None, None, None - - class _EdgeGeometryRBFFunction(torch.autograd.Function): - """Autograd wrapper for the fused edge geometry/RBF chain.""" - - @staticmethod - def forward( - ctx: Any, - coord_flat: Tensor, - center_coord_index: Tensor, - neighbor_coord_index: Tensor, - freqs: Tensor, - eps: float, - rcut: float, - edge_env_a: float, - edge_env_b: float, - edge_env_c: float, - edge_env_d: float, - edge_env_exponent: int, - radial_env_a: float, - radial_env_b: float, - radial_env_c: float, - radial_env_d: float, - radial_env_exponent: int, - r_inner: float, - r_outer: float, - has_inner_clamp: bool, - ) -> tuple[Tensor, Tensor, Tensor, Tensor]: - freq_flat = freqs.reshape(-1) - num_edges = int(center_coord_index.shape[0]) - edge_vec = torch.empty( - num_edges, - 3, - dtype=coord_flat.dtype, - device=coord_flat.device, - ) - edge_len = torch.empty( - num_edges, - dtype=coord_flat.dtype, - device=coord_flat.device, - ) - edge_env = torch.empty( - num_edges, - dtype=coord_flat.dtype, - device=coord_flat.device, - ) - edge_rbf = torch.empty( - num_edges, - freq_flat.numel(), - dtype=coord_flat.dtype, - device=coord_flat.device, - ) - torch.ops.deepmd._kernel_sezm_edge_geometry_rbf( - coord_flat, - center_coord_index, - neighbor_coord_index, - freq_flat, - edge_vec, - edge_len, - edge_env, - edge_rbf, - float(eps), - float(rcut), - float(edge_env_a), - float(edge_env_b), - float(edge_env_c), - float(edge_env_d), - int(edge_env_exponent), - float(radial_env_a), - float(radial_env_b), - float(radial_env_c), - float(radial_env_d), - int(radial_env_exponent), - float(r_inner), - float(r_outer), - bool(has_inner_clamp), - ) - ctx.save_for_backward( - coord_flat, - center_coord_index, - neighbor_coord_index, - freqs, - ) - ctx.eps = float(eps) - ctx.rcut = float(rcut) - ctx.edge_env_a = float(edge_env_a) - ctx.edge_env_b = float(edge_env_b) - ctx.edge_env_c = float(edge_env_c) - ctx.edge_env_d = float(edge_env_d) - ctx.edge_env_exponent = int(edge_env_exponent) - ctx.radial_env_a = float(radial_env_a) - ctx.radial_env_b = float(radial_env_b) - ctx.radial_env_c = float(radial_env_c) - ctx.radial_env_d = float(radial_env_d) - ctx.radial_env_exponent = int(radial_env_exponent) - ctx.r_inner = float(r_inner) - ctx.r_outer = float(r_outer) - ctx.has_inner_clamp = bool(has_inner_clamp) - return edge_vec, edge_len.unsqueeze(-1), edge_env.unsqueeze(-1), edge_rbf - - @staticmethod - def backward( - ctx: Any, - grad_edge_vec: Tensor | None, - grad_edge_len: Tensor | None, - grad_edge_env: Tensor | None, - grad_edge_rbf: Tensor | None, - ) -> tuple[ - Tensor, - None, - None, - Tensor, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - ]: - coord_flat, center_coord_index, neighbor_coord_index, freqs = ( - ctx.saved_tensors - ) - num_edges = int(center_coord_index.shape[0]) - freq_flat = freqs.reshape(-1) - - if grad_edge_vec is None: - grad_edge_vec = torch.zeros( - num_edges, - 3, - dtype=coord_flat.dtype, - device=coord_flat.device, - ) - else: - grad_edge_vec = grad_edge_vec.contiguous() - if grad_edge_len is None: - grad_edge_len = torch.zeros( - num_edges, - dtype=coord_flat.dtype, - device=coord_flat.device, - ) - else: - grad_edge_len = grad_edge_len.contiguous().squeeze(-1) - if grad_edge_env is None: - grad_edge_env = torch.zeros( - num_edges, - dtype=coord_flat.dtype, - device=coord_flat.device, - ) - else: - grad_edge_env = grad_edge_env.contiguous().squeeze(-1) - if grad_edge_rbf is None: - grad_edge_rbf = torch.zeros( - num_edges, - freq_flat.numel(), - dtype=coord_flat.dtype, - device=coord_flat.device, - ) - else: - grad_edge_rbf = grad_edge_rbf.contiguous() - - grad_r_total = torch.zeros( - num_edges, - dtype=coord_flat.dtype, - device=coord_flat.device, - ) - grad_freq = torch.zeros( - freq_flat.numel(), - dtype=freq_flat.dtype, - device=coord_flat.device, - ) - torch.ops.deepmd._kernel_sezm_edge_geometry_rbf_bwd_accum( - grad_edge_len, - grad_edge_env, - grad_edge_rbf, - coord_flat, - center_coord_index, - neighbor_coord_index, - freq_flat, - grad_r_total, - grad_freq, - float(ctx.eps), - float(ctx.rcut), - float(ctx.edge_env_a), - float(ctx.edge_env_b), - float(ctx.edge_env_c), - float(ctx.edge_env_d), - int(ctx.edge_env_exponent), - float(ctx.radial_env_a), - float(ctx.radial_env_b), - float(ctx.radial_env_c), - float(ctx.radial_env_d), - int(ctx.radial_env_exponent), - float(ctx.r_inner), - float(ctx.r_outer), - bool(ctx.has_inner_clamp), - ) - grad_coord = torch.zeros_like(coord_flat) - torch.ops.deepmd._kernel_sezm_edge_geometry_rbf_bwd_coord( - grad_edge_vec, - grad_r_total, - coord_flat, - center_coord_index, - neighbor_coord_index, - grad_coord, - float(ctx.eps), - float(ctx.r_inner), - float(ctx.r_outer), - bool(ctx.has_inner_clamp), - ) - return ( - grad_coord, - None, - None, - grad_freq.view_as(freqs), - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - ) - - -def rotate_to_local_triton( - x: Tensor, - src: Tensor, - wigner: Tensor, - coeff_index: Tensor, - dim_full: int, - rotation_mode: int | TritonRotationMode | None = None, -) -> Tensor: - """ - Apply the fused ``global -> local reduced`` rotation. - - Parameters - ---------- - x - Node features with shape ``(N, D, C)``. - src - Source-node indices with shape ``(E,)``. - wigner - Packed Wigner matrices with shape ``(E, D, D)``. - coeff_index - Reduced-layout row indices with shape ``(D_m,)``. - dim_full - Full packed SO(3) dimension. - rotation_mode - Optional pre-resolved dispatch mode. - - Returns - ------- - Tensor - Rotated reduced-layout edge features with shape ``(E, D_m, C)``. - """ - if not SEZM_TRITON_AVAILABLE: - raise RuntimeError("SeZM Triton kernels are not available in this environment.") - src = src.contiguous() - coeff_index = coeff_index.contiguous() - resolved_mode = _resolve_rotation_mode_for_call( - dim_full=int(dim_full), - coeff_index=coeff_index, - rotation_mode=rotation_mode, - ) - if resolved_mode == TritonRotationMode.EAGER_REFERENCE: - return _rotate_to_local_eager( - x=x, - src=src, - wigner=wigner, - coeff_index=coeff_index, - dim_full=int(dim_full), - ) - return _RotateToLocalFunction.apply( - x, - src, - wigner, - coeff_index, - int(dim_full), - int(resolved_mode), - ) - - -def rotate_back_triton( - x_local: Tensor, - wigner: Tensor, - coeff_index: Tensor, - dim_full: int, - rotation_mode: int | TritonRotationMode | None = None, -) -> Tensor: - """ - Apply the fused ``local reduced -> global`` rotation. - - Parameters - ---------- - x_local - Reduced-layout edge features with shape ``(E, D_m, C)``. - wigner - Packed Wigner matrices with shape ``(E, D, D)``. - coeff_index - Reduced-layout column indices with shape ``(D_m,)``. - dim_full - Full packed SO(3) dimension. - rotation_mode - Optional pre-resolved dispatch mode. - - Returns - ------- - Tensor - Lifted global-layout edge features with shape ``(E, D, C)``. - """ - if not SEZM_TRITON_AVAILABLE: - raise RuntimeError("SeZM Triton kernels are not available in this environment.") - coeff_index = coeff_index.contiguous() - resolved_mode = _resolve_rotation_mode_for_call( - dim_full=int(dim_full), - coeff_index=coeff_index, - rotation_mode=rotation_mode, - ) - if resolved_mode == TritonRotationMode.EAGER_REFERENCE: - return _rotate_back_eager( - x_local=x_local, - wigner=wigner, - coeff_index=coeff_index, - dim_full=int(dim_full), - ) - return _RotateBackFunction.apply( - x_local, - wigner, - coeff_index, - int(dim_full), - int(resolved_mode), - ) - - -def edge_geometry_rbf_triton( - *, - coord_flat: Tensor, - center_coord_index: Tensor, - neighbor_coord_index: Tensor, - edge_envelope: Any, - radial_basis: Any, - eps: float, - inner_clamp: Any, -) -> tuple[Tensor, Tensor, Tensor, Tensor]: - """Apply the fused edge geometry/RBF chain with eager fallback.""" - ( - rcut, - edge_env_a, - edge_env_b, - edge_env_c, - edge_env_d, - edge_env_exponent, - radial_env_a, - radial_env_b, - radial_env_c, - radial_env_d, - radial_env_exponent, - r_inner, - r_outer, - has_inner_clamp, - ) = _extract_edge_geometry_rbf_constants( - edge_envelope=edge_envelope, - radial_basis=radial_basis, - inner_clamp=inner_clamp, - ) - center_coord_index = center_coord_index.contiguous() - neighbor_coord_index = neighbor_coord_index.contiguous() - freqs = radial_basis.adam_freqs.contiguous() - if ( - center_coord_index.numel() == 0 - or not SEZM_TRITON_AVAILABLE - or coord_flat.device.type != "cuda" - or coord_flat.dtype not in (torch.float16, torch.bfloat16, torch.float32) - ): - return _edge_geometry_rbf_eager( - coord_flat=coord_flat, - center_coord_index=center_coord_index, - neighbor_coord_index=neighbor_coord_index, - freqs=freqs, - eps=float(eps), - rcut=rcut, - edge_env_a=edge_env_a, - edge_env_b=edge_env_b, - edge_env_c=edge_env_c, - edge_env_d=edge_env_d, - edge_env_exponent=edge_env_exponent, - radial_env_a=radial_env_a, - radial_env_b=radial_env_b, - radial_env_c=radial_env_c, - radial_env_d=radial_env_d, - radial_env_exponent=radial_env_exponent, - r_inner=r_inner, - r_outer=r_outer, - has_inner_clamp=has_inner_clamp, - ) - return _EdgeGeometryRBFFunction.apply( - coord_flat, - center_coord_index, - neighbor_coord_index, - freqs, - float(eps), - rcut, - edge_env_a, - edge_env_b, - edge_env_c, - edge_env_d, - edge_env_exponent, - radial_env_a, - radial_env_b, - radial_env_c, - radial_env_d, - radial_env_exponent, - r_inner, - r_outer, - has_inner_clamp, - ) diff --git a/deepmd/pt/model/descriptor/sezm_nn/triton/constants.py b/deepmd/pt/model/descriptor/sezm_nn/triton/constants.py deleted file mode 100644 index c2aabb8147..0000000000 --- a/deepmd/pt/model/descriptor/sezm_nn/triton/constants.py +++ /dev/null @@ -1,46 +0,0 @@ -# SPDX-License-Identifier: LGPL-3.0-or-later -"""Shared constants and feature flags for SeZM Triton kernels.""" - -from __future__ import ( - annotations, -) - -from enum import ( - IntEnum, -) - -import torch - -_HAS_TORCH_TRITON_OP = hasattr(torch.library, "triton_op") and hasattr( - torch.library, "wrap_triton" -) - -if _HAS_TORCH_TRITON_OP: - try: - import triton # noqa: F401 - except ImportError: - SEZM_TRITON_AVAILABLE = False - else: - SEZM_TRITON_AVAILABLE = True -else: - SEZM_TRITON_AVAILABLE = False - -# Triton dot kernels require K >= 16 on the current CUDA backend. -TRITON_GRID_E_STRIDE = 2048 -TRITON_BLOCK_FULL = 16 -TRITON_BLOCK_REDUCED = 16 -TRITON_BLOCK_CHANNEL = 32 -TRITON_SMALL_BLOCK_CHANNEL = 128 -TRITON_SMALL_FULL_DIM = 16 -TRITON_EDGE_GEOMETRY_RBF_BLOCK_EDGE = 128 -TRITON_EDGE_GEOMETRY_RBF_BLOCK_RADIAL = 16 - - -class TritonRotationMode(IntEnum): - """Dispatch mode for the SeZM rotation hot path.""" - - GENERIC_TILED = 0 - SMALL_LE1 = 1 - SMALL_L2 = 2 - SMALL_L3 = 3 - EAGER_REFERENCE = 4 diff --git a/deepmd/pt/model/descriptor/sezm_nn/triton/custom_ops.py b/deepmd/pt/model/descriptor/sezm_nn/triton/custom_ops.py deleted file mode 100644 index 23b31aa2f5..0000000000 --- a/deepmd/pt/model/descriptor/sezm_nn/triton/custom_ops.py +++ /dev/null @@ -1,861 +0,0 @@ -# SPDX-License-Identifier: LGPL-3.0-or-later -"""Triton custom-op launchers for SeZM SO(2) rotation kernels. - -This layer only decides how to launch a resolved dispatch mode. Fallback policy -stays in the public autograd API so the launchers remain focused on Triton -grids, kernel families, and argument packing. -""" - -from __future__ import ( - annotations, -) - -import torch # noqa: TC002 - -from .constants import ( - SEZM_TRITON_AVAILABLE, - TRITON_BLOCK_CHANNEL, - TRITON_BLOCK_FULL, - TRITON_BLOCK_REDUCED, - TRITON_EDGE_GEOMETRY_RBF_BLOCK_EDGE, - TRITON_EDGE_GEOMETRY_RBF_BLOCK_RADIAL, - TRITON_GRID_E_STRIDE, - TRITON_SMALL_BLOCK_CHANNEL, - TritonRotationMode, -) -from .dispatch import ( - coerce_rotation_mode, -) - - -def _require_kernel_mode( - rotation_mode: int | TritonRotationMode, -) -> TritonRotationMode: - """Reject eager fallback before entering the Triton launch layer.""" - resolved_mode = coerce_rotation_mode(rotation_mode) - if resolved_mode == TritonRotationMode.EAGER_REFERENCE: - raise ValueError("Eager reference mode must be handled before Triton launch.") - return resolved_mode - - -if SEZM_TRITON_AVAILABLE: - from torch.library import ( - triton_op, - wrap_triton, - ) - - from .kernels_edge_geometry_rbf import ( - edge_geometry_rbf_bwd_accum_kernel, - edge_geometry_rbf_bwd_coord_kernel, - edge_geometry_rbf_forward_kernel, - ) - from .kernels_generic import ( - rotate_back_bwd_dw_kernel, - rotate_back_bwd_dx_kernel, - rotate_back_forward_kernel, - rotate_to_local_bwd_dw_kernel, - rotate_to_local_bwd_dx_kernel, - rotate_to_local_forward_kernel, - ) - from .kernels_small import ( - rotate_back_l1_bwd_dx_kernel, - rotate_back_l1_forward_kernel, - rotate_back_l2_bwd_dx_kernel, - rotate_back_l2_forward_kernel, - rotate_back_l3_bwd_dx_kernel, - rotate_back_l3_forward_kernel, - rotate_back_small_bwd_dw_kernel, - rotate_to_local_l1_bwd_dx_kernel, - rotate_to_local_l1_forward_kernel, - rotate_to_local_l2_bwd_dx_kernel, - rotate_to_local_l2_forward_kernel, - rotate_to_local_l3_bwd_dx_kernel, - rotate_to_local_l3_forward_kernel, - rotate_to_local_small_bwd_dw_kernel, - ) - - _ROTATE_TO_LOCAL_SMALL_FORWARD = { - TritonRotationMode.SMALL_LE1: rotate_to_local_l1_forward_kernel, - TritonRotationMode.SMALL_L2: rotate_to_local_l2_forward_kernel, - TritonRotationMode.SMALL_L3: rotate_to_local_l3_forward_kernel, - } - _ROTATE_TO_LOCAL_SMALL_BWD_DX = { - TritonRotationMode.SMALL_LE1: rotate_to_local_l1_bwd_dx_kernel, - TritonRotationMode.SMALL_L2: rotate_to_local_l2_bwd_dx_kernel, - TritonRotationMode.SMALL_L3: rotate_to_local_l3_bwd_dx_kernel, - } - _ROTATE_BACK_SMALL_FORWARD = { - TritonRotationMode.SMALL_LE1: rotate_back_l1_forward_kernel, - TritonRotationMode.SMALL_L2: rotate_back_l2_forward_kernel, - TritonRotationMode.SMALL_L3: rotate_back_l3_forward_kernel, - } - _ROTATE_BACK_SMALL_BWD_DX = { - TritonRotationMode.SMALL_LE1: rotate_back_l1_bwd_dx_kernel, - TritonRotationMode.SMALL_L2: rotate_back_l2_bwd_dx_kernel, - TritonRotationMode.SMALL_L3: rotate_back_l3_bwd_dx_kernel, - } - - def _small_channel_grid(channels: int) -> tuple[int, int]: - """Return the standard ``(edge, channel)`` grid for small kernels.""" - return ( - TRITON_GRID_E_STRIDE, - (channels + TRITON_SMALL_BLOCK_CHANNEL - 1) // TRITON_SMALL_BLOCK_CHANNEL, - ) - - def _generic_rotate_to_local_forward_grid( - reduced_dim: int, - channels: int, - ) -> tuple[int, int, int]: - """Return the standard forward grid for generic rotate-to-local.""" - return ( - TRITON_GRID_E_STRIDE, - (reduced_dim + TRITON_BLOCK_REDUCED - 1) // TRITON_BLOCK_REDUCED, - (channels + TRITON_BLOCK_CHANNEL - 1) // TRITON_BLOCK_CHANNEL, - ) - - def _generic_rotate_to_local_bwd_dx_grid( - dim_full: int, - channels: int, - ) -> tuple[int, int, int]: - """Return the source-gradient grid for generic rotate-to-local.""" - return ( - TRITON_GRID_E_STRIDE, - (dim_full + TRITON_BLOCK_FULL - 1) // TRITON_BLOCK_FULL, - (channels + TRITON_BLOCK_CHANNEL - 1) // TRITON_BLOCK_CHANNEL, - ) - - def _generic_rotate_to_local_bwd_dw_grid( - reduced_dim: int, - dim_full: int, - ) -> tuple[int, int, int]: - """Return the Wigner-gradient grid for generic rotate-to-local.""" - return ( - TRITON_GRID_E_STRIDE, - (reduced_dim + TRITON_BLOCK_REDUCED - 1) // TRITON_BLOCK_REDUCED, - (dim_full + TRITON_BLOCK_FULL - 1) // TRITON_BLOCK_FULL, - ) - - def _generic_rotate_back_forward_grid( - dim_full: int, - channels: int, - ) -> tuple[int, int, int]: - """Return the standard forward grid for generic rotate-back.""" - return ( - TRITON_GRID_E_STRIDE, - (dim_full + TRITON_BLOCK_FULL - 1) // TRITON_BLOCK_FULL, - (channels + TRITON_BLOCK_CHANNEL - 1) // TRITON_BLOCK_CHANNEL, - ) - - def _generic_rotate_back_bwd_dx_grid( - reduced_dim: int, - channels: int, - ) -> tuple[int, int, int]: - """Return the reduced-gradient grid for generic rotate-back.""" - return ( - TRITON_GRID_E_STRIDE, - (reduced_dim + TRITON_BLOCK_REDUCED - 1) // TRITON_BLOCK_REDUCED, - (channels + TRITON_BLOCK_CHANNEL - 1) // TRITON_BLOCK_CHANNEL, - ) - - def _generic_rotate_back_bwd_dw_grid( - dim_full: int, - reduced_dim: int, - ) -> tuple[int, int, int]: - """Return the Wigner-gradient grid for generic rotate-back.""" - return ( - TRITON_GRID_E_STRIDE, - (dim_full + TRITON_BLOCK_FULL - 1) // TRITON_BLOCK_FULL, - (reduced_dim + TRITON_BLOCK_REDUCED - 1) // TRITON_BLOCK_REDUCED, - ) - - def _edge_geometry_rbf_grid(num_edges: int, n_radial: int) -> tuple[int, int]: - """Return the standard grid for the fused edge geometry/RBF chain.""" - return ( - (num_edges + TRITON_EDGE_GEOMETRY_RBF_BLOCK_EDGE - 1) - // TRITON_EDGE_GEOMETRY_RBF_BLOCK_EDGE, - (n_radial + TRITON_EDGE_GEOMETRY_RBF_BLOCK_RADIAL - 1) - // TRITON_EDGE_GEOMETRY_RBF_BLOCK_RADIAL, - ) - - def _edge_geometry_rbf_coord_grid(num_edges: int) -> tuple[int]: - """Return the edge-only grid for geometry/RBF coordinate gradients.""" - return ( - (num_edges + TRITON_EDGE_GEOMETRY_RBF_BLOCK_EDGE - 1) - // TRITON_EDGE_GEOMETRY_RBF_BLOCK_EDGE, - ) - - def _launch_rotate_to_local_small_forward( - *, - rotation_mode: TritonRotationMode, - x: torch.Tensor, - src: torch.Tensor, - wigner: torch.Tensor, - coeff_index: torch.Tensor, - out: torch.Tensor, - dim_full: int, - ) -> None: - """Launch one specialized small-family rotate-to-local forward kernel.""" - reduced_dim = coeff_index.numel() - channels = x.shape[2] - kernel = _ROTATE_TO_LOCAL_SMALL_FORWARD[rotation_mode] - wrap_triton(kernel)[_small_channel_grid(channels)]( - x, - src, - wigner, - coeff_index, - out, - src.shape[0], - reduced_dim, - dim_full, - channels, - x.stride(0), - x.stride(1), - x.stride(2), - wigner.stride(0), - wigner.stride(1), - wigner.stride(2), - out.stride(0), - out.stride(1), - out.stride(2), - BLOCK_CHANNEL=TRITON_SMALL_BLOCK_CHANNEL, - GRID_E_STRIDE=TRITON_GRID_E_STRIDE, - num_warps=1, - ) - - def _launch_rotate_to_local_small_bwd_dx( - *, - rotation_mode: TritonRotationMode, - grad_out: torch.Tensor, - wigner: torch.Tensor, - coeff_index: torch.Tensor, - grad_edge: torch.Tensor, - dim_full: int, - ) -> None: - """Launch one specialized small-family rotate-to-local dx kernel.""" - reduced_dim = coeff_index.numel() - channels = grad_out.shape[2] - kernel = _ROTATE_TO_LOCAL_SMALL_BWD_DX[rotation_mode] - wrap_triton(kernel)[_small_channel_grid(channels)]( - grad_out, - wigner, - coeff_index, - grad_edge, - grad_out.shape[0], - reduced_dim, - dim_full, - channels, - grad_out.stride(0), - grad_out.stride(1), - grad_out.stride(2), - wigner.stride(0), - wigner.stride(1), - wigner.stride(2), - grad_edge.stride(0), - grad_edge.stride(1), - grad_edge.stride(2), - BLOCK_CHANNEL=TRITON_SMALL_BLOCK_CHANNEL, - GRID_E_STRIDE=TRITON_GRID_E_STRIDE, - num_warps=1, - ) - - def _launch_rotate_back_small_forward( - *, - rotation_mode: TritonRotationMode, - x_local: torch.Tensor, - wigner: torch.Tensor, - coeff_index: torch.Tensor, - out: torch.Tensor, - dim_full: int, - ) -> None: - """Launch one specialized small-family rotate-back forward kernel.""" - reduced_dim = coeff_index.numel() - channels = x_local.shape[2] - kernel = _ROTATE_BACK_SMALL_FORWARD[rotation_mode] - wrap_triton(kernel)[_small_channel_grid(channels)]( - x_local, - wigner, - coeff_index, - out, - x_local.shape[0], - reduced_dim, - dim_full, - channels, - x_local.stride(0), - x_local.stride(1), - x_local.stride(2), - wigner.stride(0), - wigner.stride(1), - wigner.stride(2), - out.stride(0), - out.stride(1), - out.stride(2), - BLOCK_CHANNEL=TRITON_SMALL_BLOCK_CHANNEL, - GRID_E_STRIDE=TRITON_GRID_E_STRIDE, - num_warps=1, - ) - - def _launch_rotate_back_small_bwd_dx( - *, - rotation_mode: TritonRotationMode, - grad_out: torch.Tensor, - wigner: torch.Tensor, - coeff_index: torch.Tensor, - grad_x_local: torch.Tensor, - dim_full: int, - ) -> None: - """Launch one specialized small-family rotate-back dx kernel.""" - reduced_dim = coeff_index.numel() - channels = grad_out.shape[2] - kernel = _ROTATE_BACK_SMALL_BWD_DX[rotation_mode] - wrap_triton(kernel)[_small_channel_grid(channels)]( - grad_out, - wigner, - coeff_index, - grad_x_local, - grad_out.shape[0], - reduced_dim, - dim_full, - channels, - grad_out.stride(0), - grad_out.stride(1), - grad_out.stride(2), - wigner.stride(0), - wigner.stride(1), - wigner.stride(2), - grad_x_local.stride(0), - grad_x_local.stride(1), - grad_x_local.stride(2), - BLOCK_CHANNEL=TRITON_SMALL_BLOCK_CHANNEL, - GRID_E_STRIDE=TRITON_GRID_E_STRIDE, - num_warps=1, - ) - - @triton_op( - "deepmd::_kernel_sezm_rotate_to_local", - mutates_args=("out",), - ) - def _kernel_sezm_rotate_to_local( - x: torch.Tensor, - src: torch.Tensor, - wigner: torch.Tensor, - coeff_index: torch.Tensor, - out: torch.Tensor, - dim_full: int, - rotation_mode: int, - ) -> None: - """Launch the fused Triton forward kernel for ``D_to_m @ x[src]``.""" - mode = _require_kernel_mode(rotation_mode) - reduced_dim = coeff_index.numel() - channels = x.shape[2] - if mode != TritonRotationMode.GENERIC_TILED: - _launch_rotate_to_local_small_forward( - rotation_mode=mode, - x=x, - src=src, - wigner=wigner, - coeff_index=coeff_index, - out=out, - dim_full=dim_full, - ) - return - wrap_triton(rotate_to_local_forward_kernel)[ - _generic_rotate_to_local_forward_grid(reduced_dim, channels) - ]( - x, - src, - wigner, - coeff_index, - out, - src.shape[0], - reduced_dim, - dim_full, - channels, - x.stride(0), - x.stride(1), - x.stride(2), - wigner.stride(0), - wigner.stride(1), - wigner.stride(2), - out.stride(0), - out.stride(1), - out.stride(2), - BLOCK_REDUCED=TRITON_BLOCK_REDUCED, - BLOCK_FULL=TRITON_BLOCK_FULL, - BLOCK_CHANNEL=TRITON_BLOCK_CHANNEL, - GRID_E_STRIDE=TRITON_GRID_E_STRIDE, - num_warps=1, - ) - - @triton_op( - "deepmd::_kernel_sezm_rotate_to_local_bwd_dx", - mutates_args=("grad_edge",), - ) - def _kernel_sezm_rotate_to_local_bwd_dx( - grad_out: torch.Tensor, - wigner: torch.Tensor, - coeff_index: torch.Tensor, - grad_edge: torch.Tensor, - dim_full: int, - rotation_mode: int, - ) -> None: - """Launch the Triton backward kernel for source-feature gradients.""" - mode = _require_kernel_mode(rotation_mode) - reduced_dim = coeff_index.numel() - channels = grad_out.shape[2] - if mode != TritonRotationMode.GENERIC_TILED: - _launch_rotate_to_local_small_bwd_dx( - rotation_mode=mode, - grad_out=grad_out, - wigner=wigner, - coeff_index=coeff_index, - grad_edge=grad_edge, - dim_full=dim_full, - ) - return - wrap_triton(rotate_to_local_bwd_dx_kernel)[ - _generic_rotate_to_local_bwd_dx_grid(dim_full, channels) - ]( - grad_out, - wigner, - coeff_index, - grad_edge, - grad_out.shape[0], - reduced_dim, - dim_full, - channels, - grad_out.stride(0), - grad_out.stride(1), - grad_out.stride(2), - wigner.stride(0), - wigner.stride(1), - wigner.stride(2), - grad_edge.stride(0), - grad_edge.stride(1), - grad_edge.stride(2), - BLOCK_REDUCED=TRITON_BLOCK_REDUCED, - BLOCK_FULL=TRITON_BLOCK_FULL, - BLOCK_CHANNEL=TRITON_BLOCK_CHANNEL, - GRID_E_STRIDE=TRITON_GRID_E_STRIDE, - num_warps=1, - ) - - @triton_op( - "deepmd::_kernel_sezm_rotate_to_local_bwd_dw", - mutates_args=("grad_wigner",), - ) - def _kernel_sezm_rotate_to_local_bwd_dw( - grad_out: torch.Tensor, - x: torch.Tensor, - src: torch.Tensor, - coeff_index: torch.Tensor, - grad_wigner: torch.Tensor, - dim_full: int, - rotation_mode: int, - ) -> None: - """Launch the Triton backward kernel for Wigner gradients.""" - mode = _require_kernel_mode(rotation_mode) - reduced_dim = coeff_index.numel() - channels = grad_out.shape[2] - if mode != TritonRotationMode.GENERIC_TILED: - wrap_triton(rotate_to_local_small_bwd_dw_kernel)[(TRITON_GRID_E_STRIDE,)]( - grad_out, - x, - src, - coeff_index, - grad_wigner, - grad_out.shape[0], - reduced_dim, - dim_full, - channels, - grad_out.stride(0), - grad_out.stride(1), - grad_out.stride(2), - x.stride(0), - x.stride(1), - x.stride(2), - grad_wigner.stride(0), - grad_wigner.stride(1), - grad_wigner.stride(2), - BLOCK_CHANNEL=TRITON_SMALL_BLOCK_CHANNEL, - GRID_E_STRIDE=TRITON_GRID_E_STRIDE, - num_warps=1, - ) - return - wrap_triton(rotate_to_local_bwd_dw_kernel)[ - _generic_rotate_to_local_bwd_dw_grid(reduced_dim, dim_full) - ]( - grad_out, - x, - src, - coeff_index, - grad_wigner, - grad_out.shape[0], - reduced_dim, - dim_full, - channels, - grad_out.stride(0), - grad_out.stride(1), - grad_out.stride(2), - x.stride(0), - x.stride(1), - x.stride(2), - grad_wigner.stride(0), - grad_wigner.stride(1), - grad_wigner.stride(2), - BLOCK_REDUCED=TRITON_BLOCK_REDUCED, - BLOCK_FULL=TRITON_BLOCK_FULL, - BLOCK_CHANNEL=TRITON_BLOCK_CHANNEL, - GRID_E_STRIDE=TRITON_GRID_E_STRIDE, - num_warps=1, - ) - - @triton_op( - "deepmd::_kernel_sezm_rotate_back", - mutates_args=("out",), - ) - def _kernel_sezm_rotate_back( - x_local: torch.Tensor, - wigner: torch.Tensor, - coeff_index: torch.Tensor, - out: torch.Tensor, - dim_full: int, - rotation_mode: int, - ) -> None: - """Launch the fused Triton forward kernel for ``Dt_from_m @ x_local``.""" - mode = _require_kernel_mode(rotation_mode) - reduced_dim = coeff_index.numel() - channels = x_local.shape[2] - if mode != TritonRotationMode.GENERIC_TILED: - _launch_rotate_back_small_forward( - rotation_mode=mode, - x_local=x_local, - wigner=wigner, - coeff_index=coeff_index, - out=out, - dim_full=dim_full, - ) - return - wrap_triton(rotate_back_forward_kernel)[ - _generic_rotate_back_forward_grid(dim_full, channels) - ]( - x_local, - wigner, - coeff_index, - out, - x_local.shape[0], - reduced_dim, - dim_full, - channels, - x_local.stride(0), - x_local.stride(1), - x_local.stride(2), - wigner.stride(0), - wigner.stride(1), - wigner.stride(2), - out.stride(0), - out.stride(1), - out.stride(2), - BLOCK_REDUCED=TRITON_BLOCK_REDUCED, - BLOCK_FULL=TRITON_BLOCK_FULL, - BLOCK_CHANNEL=TRITON_BLOCK_CHANNEL, - GRID_E_STRIDE=TRITON_GRID_E_STRIDE, - num_warps=1, - ) - - @triton_op( - "deepmd::_kernel_sezm_rotate_back_bwd_dx", - mutates_args=("grad_x_local",), - ) - def _kernel_sezm_rotate_back_bwd_dx( - grad_out: torch.Tensor, - wigner: torch.Tensor, - coeff_index: torch.Tensor, - grad_x_local: torch.Tensor, - dim_full: int, - rotation_mode: int, - ) -> None: - """Launch the Triton backward kernel for reduced-layout gradients.""" - mode = _require_kernel_mode(rotation_mode) - reduced_dim = coeff_index.numel() - channels = grad_out.shape[2] - if mode != TritonRotationMode.GENERIC_TILED: - _launch_rotate_back_small_bwd_dx( - rotation_mode=mode, - grad_out=grad_out, - wigner=wigner, - coeff_index=coeff_index, - grad_x_local=grad_x_local, - dim_full=dim_full, - ) - return - wrap_triton(rotate_back_bwd_dx_kernel)[ - _generic_rotate_back_bwd_dx_grid(reduced_dim, channels) - ]( - grad_out, - wigner, - coeff_index, - grad_x_local, - grad_out.shape[0], - reduced_dim, - dim_full, - channels, - grad_out.stride(0), - grad_out.stride(1), - grad_out.stride(2), - wigner.stride(0), - wigner.stride(1), - wigner.stride(2), - grad_x_local.stride(0), - grad_x_local.stride(1), - grad_x_local.stride(2), - BLOCK_REDUCED=TRITON_BLOCK_REDUCED, - BLOCK_FULL=TRITON_BLOCK_FULL, - BLOCK_CHANNEL=TRITON_BLOCK_CHANNEL, - GRID_E_STRIDE=TRITON_GRID_E_STRIDE, - num_warps=1, - ) - - @triton_op( - "deepmd::_kernel_sezm_rotate_back_bwd_dw", - mutates_args=("grad_wigner",), - ) - def _kernel_sezm_rotate_back_bwd_dw( - grad_out: torch.Tensor, - x_local: torch.Tensor, - coeff_index: torch.Tensor, - grad_wigner: torch.Tensor, - dim_full: int, - rotation_mode: int, - ) -> None: - """Launch the Triton backward kernel for Wigner gradients.""" - mode = _require_kernel_mode(rotation_mode) - reduced_dim = coeff_index.numel() - channels = grad_out.shape[2] - if mode != TritonRotationMode.GENERIC_TILED: - wrap_triton(rotate_back_small_bwd_dw_kernel)[(TRITON_GRID_E_STRIDE,)]( - grad_out, - x_local, - coeff_index, - grad_wigner, - grad_out.shape[0], - x_local.shape[1], - dim_full, - channels, - grad_out.stride(0), - grad_out.stride(1), - grad_out.stride(2), - x_local.stride(0), - x_local.stride(1), - x_local.stride(2), - grad_wigner.stride(0), - grad_wigner.stride(1), - grad_wigner.stride(2), - BLOCK_CHANNEL=TRITON_SMALL_BLOCK_CHANNEL, - GRID_E_STRIDE=TRITON_GRID_E_STRIDE, - num_warps=1, - ) - return - wrap_triton(rotate_back_bwd_dw_kernel)[ - _generic_rotate_back_bwd_dw_grid(dim_full, reduced_dim) - ]( - grad_out, - x_local, - grad_wigner, - grad_out.shape[0], - x_local.shape[1], - dim_full, - channels, - grad_out.stride(0), - grad_out.stride(1), - grad_out.stride(2), - x_local.stride(0), - x_local.stride(1), - x_local.stride(2), - grad_wigner.stride(0), - grad_wigner.stride(1), - grad_wigner.stride(2), - BLOCK_REDUCED=TRITON_BLOCK_REDUCED, - BLOCK_FULL=TRITON_BLOCK_FULL, - BLOCK_CHANNEL=TRITON_BLOCK_CHANNEL, - GRID_E_STRIDE=TRITON_GRID_E_STRIDE, - num_warps=1, - ) - - @triton_op( - "deepmd::_kernel_sezm_edge_geometry_rbf", - mutates_args=("edge_vec", "edge_len", "edge_env", "edge_rbf"), - ) - def _kernel_sezm_edge_geometry_rbf( - coord_flat: torch.Tensor, - center_coord_index: torch.Tensor, - neighbor_coord_index: torch.Tensor, - freqs: torch.Tensor, - edge_vec: torch.Tensor, - edge_len: torch.Tensor, - edge_env: torch.Tensor, - edge_rbf: torch.Tensor, - eps: float, - rcut: float, - edge_env_a: float, - edge_env_b: float, - edge_env_c: float, - edge_env_d: float, - edge_env_exponent: int, - radial_env_a: float, - radial_env_b: float, - radial_env_c: float, - radial_env_d: float, - radial_env_exponent: int, - r_inner: float, - r_outer: float, - has_inner_clamp: bool, - ) -> None: - """Launch the fused edge geometry/RBF forward kernel.""" - wrap_triton(edge_geometry_rbf_forward_kernel)[ - _edge_geometry_rbf_grid(center_coord_index.shape[0], freqs.numel()) - ]( - coord_flat, - center_coord_index, - neighbor_coord_index, - freqs, - edge_vec, - edge_len, - edge_env, - edge_rbf, - center_coord_index.shape[0], - freqs.numel(), - coord_flat.stride(0), - coord_flat.stride(1), - edge_vec.stride(0), - edge_vec.stride(1), - edge_rbf.stride(0), - edge_rbf.stride(1), - eps, - rcut, - edge_env_a, - edge_env_b, - edge_env_c, - edge_env_d, - radial_env_a, - radial_env_b, - radial_env_c, - radial_env_d, - r_inner, - r_outer, - EDGE_ENV_EXPONENT=int(edge_env_exponent), - RADIAL_ENV_EXPONENT=int(radial_env_exponent), - HAS_INNER_CLAMP=bool(has_inner_clamp), - BLOCK_EDGE=TRITON_EDGE_GEOMETRY_RBF_BLOCK_EDGE, - BLOCK_RADIAL=TRITON_EDGE_GEOMETRY_RBF_BLOCK_RADIAL, - num_warps=4, - ) - - @triton_op( - "deepmd::_kernel_sezm_edge_geometry_rbf_bwd_accum", - mutates_args=("grad_r_total", "grad_freq"), - ) - def _kernel_sezm_edge_geometry_rbf_bwd_accum( - grad_edge_len: torch.Tensor, - grad_edge_env: torch.Tensor, - grad_edge_rbf: torch.Tensor, - coord_flat: torch.Tensor, - center_coord_index: torch.Tensor, - neighbor_coord_index: torch.Tensor, - freqs: torch.Tensor, - grad_r_total: torch.Tensor, - grad_freq: torch.Tensor, - eps: float, - rcut: float, - edge_env_a: float, - edge_env_b: float, - edge_env_c: float, - edge_env_d: float, - edge_env_exponent: int, - radial_env_a: float, - radial_env_b: float, - radial_env_c: float, - radial_env_d: float, - radial_env_exponent: int, - r_inner: float, - r_outer: float, - has_inner_clamp: bool, - ) -> None: - """Launch the fused edge geometry/RBF accumulation kernel.""" - wrap_triton(edge_geometry_rbf_bwd_accum_kernel)[ - _edge_geometry_rbf_grid(center_coord_index.shape[0], freqs.numel()) - ]( - grad_edge_len, - grad_edge_env, - grad_edge_rbf, - coord_flat, - center_coord_index, - neighbor_coord_index, - freqs, - grad_r_total, - grad_freq, - center_coord_index.shape[0], - freqs.numel(), - coord_flat.stride(0), - coord_flat.stride(1), - grad_edge_rbf.stride(0), - grad_edge_rbf.stride(1), - eps, - rcut, - edge_env_a, - edge_env_b, - edge_env_c, - edge_env_d, - radial_env_a, - radial_env_b, - radial_env_c, - radial_env_d, - r_inner, - r_outer, - EDGE_ENV_EXPONENT=int(edge_env_exponent), - RADIAL_ENV_EXPONENT=int(radial_env_exponent), - HAS_INNER_CLAMP=bool(has_inner_clamp), - BLOCK_EDGE=TRITON_EDGE_GEOMETRY_RBF_BLOCK_EDGE, - BLOCK_RADIAL=TRITON_EDGE_GEOMETRY_RBF_BLOCK_RADIAL, - num_warps=4, - ) - - @triton_op( - "deepmd::_kernel_sezm_edge_geometry_rbf_bwd_coord", - mutates_args=("grad_coord",), - ) - def _kernel_sezm_edge_geometry_rbf_bwd_coord( - grad_edge_vec: torch.Tensor, - grad_r_total: torch.Tensor, - coord_flat: torch.Tensor, - center_coord_index: torch.Tensor, - neighbor_coord_index: torch.Tensor, - grad_coord: torch.Tensor, - eps: float, - r_inner: float, - r_outer: float, - has_inner_clamp: bool, - ) -> None: - """Launch the fused edge geometry/RBF coordinate backward kernel.""" - wrap_triton(edge_geometry_rbf_bwd_coord_kernel)[ - _edge_geometry_rbf_coord_grid(center_coord_index.shape[0]) - ]( - grad_edge_vec, - grad_r_total, - coord_flat, - center_coord_index, - neighbor_coord_index, - grad_coord, - center_coord_index.shape[0], - coord_flat.stride(0), - coord_flat.stride(1), - grad_edge_vec.stride(0), - grad_edge_vec.stride(1), - grad_coord.stride(0), - grad_coord.stride(1), - eps, - r_inner, - r_outer, - HAS_INNER_CLAMP=bool(has_inner_clamp), - BLOCK_EDGE=TRITON_EDGE_GEOMETRY_RBF_BLOCK_EDGE, - num_warps=4, - ) diff --git a/deepmd/pt/model/descriptor/sezm_nn/triton/dispatch.py b/deepmd/pt/model/descriptor/sezm_nn/triton/dispatch.py deleted file mode 100644 index 5c16c6d759..0000000000 --- a/deepmd/pt/model/descriptor/sezm_nn/triton/dispatch.py +++ /dev/null @@ -1,134 +0,0 @@ -# SPDX-License-Identifier: LGPL-3.0-or-later -"""Dispatch helpers for SeZM Triton rotation kernels.""" - -from __future__ import ( - annotations, -) - -from typing import ( - Final, -) - -import torch - -from .constants import ( - SEZM_TRITON_AVAILABLE, - TRITON_BLOCK_REDUCED, - TritonRotationMode, -) - -_SMALL_MODE_FROM_DIM: Final[dict[int, TritonRotationMode]] = { - 1: TritonRotationMode.SMALL_LE1, - 4: TritonRotationMode.SMALL_LE1, - 9: TritonRotationMode.SMALL_L2, - 16: TritonRotationMode.SMALL_L3, -} - - -def coerce_rotation_mode( - rotation_mode: int | TritonRotationMode, -) -> TritonRotationMode: - """ - Convert an integer-like dispatch value to ``TritonRotationMode``. - - Parameters - ---------- - rotation_mode - Rotation dispatch value. - - Returns - ------- - TritonRotationMode - Normalized rotation dispatch mode. - """ - if isinstance(rotation_mode, TritonRotationMode): - return rotation_mode - return TritonRotationMode(int(rotation_mode)) - - -def resolve_triton_rotation_mode( - *, - dim_full: int, - reduced_dim: int, -) -> TritonRotationMode: - """ - Resolve the SeZM rotation dispatch mode. - - Parameters - ---------- - dim_full - Full packed SO(3) dimension. - reduced_dim - Truncated m-major coefficient count. - - Returns - ------- - TritonRotationMode - Dispatch mode for the current ``(dim_full, reduced_dim)`` pair. - - Raises - ------ - ValueError - If either dimension is non-positive. - """ - dim_full = int(dim_full) - reduced_dim = int(reduced_dim) - if dim_full <= 0: - raise ValueError("dim_full must be positive") - if reduced_dim <= 0: - raise ValueError("reduced_dim must be positive") - base_mode = _SMALL_MODE_FROM_DIM.get( - dim_full, - TritonRotationMode.GENERIC_TILED, - ) - if ( - base_mode == TritonRotationMode.GENERIC_TILED - and reduced_dim < TRITON_BLOCK_REDUCED - ): - return TritonRotationMode.EAGER_REFERENCE - return base_mode - - -def sezm_triton_enabled( - *, - device: torch.device, - dtype: torch.dtype, -) -> bool: - """ - Return whether SeZM should enable the Triton rotation path. - - Parameters - ---------- - device - Target device for the rotation path. - dtype - Activation dtype for the rotation path. - - Returns - ------- - bool - Whether Triton kernels are available for the given device and dtype. - """ - supported_dtypes = (torch.float16, torch.bfloat16, torch.float32) - return bool( - SEZM_TRITON_AVAILABLE and device.type == "cuda" and dtype in supported_dtypes - ) - - -def uses_triton_kernel( - rotation_mode: int | TritonRotationMode, -) -> bool: - """ - Return whether the dispatch mode launches a Triton kernel. - - Parameters - ---------- - rotation_mode - Rotation dispatch value. - - Returns - ------- - bool - ``True`` when the mode launches a Triton kernel instead of eager fallback. - """ - return coerce_rotation_mode(rotation_mode) != TritonRotationMode.EAGER_REFERENCE diff --git a/deepmd/pt/model/descriptor/sezm_nn/triton/kernels_edge_geometry_rbf.py b/deepmd/pt/model/descriptor/sezm_nn/triton/kernels_edge_geometry_rbf.py deleted file mode 100644 index b6235173c6..0000000000 --- a/deepmd/pt/model/descriptor/sezm_nn/triton/kernels_edge_geometry_rbf.py +++ /dev/null @@ -1,550 +0,0 @@ -# SPDX-License-Identifier: LGPL-3.0-or-later -# pyright: reportMissingImports=false -# ruff: noqa: ANN001, ANN201, ANN202 -"""Triton kernels for the SeZM edge geometry/RBF chain. - -This file implements the standard non-compile path hot segment: - -``coord_gather -> edge_vec -> edge_len -> inner_clamp -> edge_env -> edge_rbf`` - -The kernels intentionally stop before Wigner-D construction so the existing eager -quaternion/Wigner path remains unchanged. -""" - -from __future__ import ( - annotations, -) - -import triton -import triton.language as tl - - -@triton.jit -def _pow_int(x, power: tl.constexpr): - """Raise ``x`` to a small compile-time integer power.""" - out = x * 0.0 + 1.0 - for _ in tl.static_range(power): - out = out * x - return out - - -@triton.jit -def _safe_sinc_no_pi(x): - """Compute ``sin(x) / x`` with a short Taylor branch near zero.""" - x2 = x * x - approx = 1.0 - x2 / 6.0 + (x2 * x2) / 120.0 - regular = tl.sin(x) / x - return tl.where(tl.abs(x) < 1.0e-4, approx, regular) - - -@triton.jit -def _safe_sinc_grad_no_pi(x): - """Compute ``d/dx [sin(x) / x]`` with a short Taylor branch near zero.""" - x2 = x * x - approx = -x / 3.0 + (x * x2) / 30.0 - regular = (x * tl.cos(x) - tl.sin(x)) / x2 - return tl.where(tl.abs(x) < 1.0e-4, approx, regular) - - -@triton.jit -def _compute_cutoff_envelope( - r, - rcut, - a, - b, - c, - d, - exponent: tl.constexpr, -): - """Evaluate the C^3 cutoff envelope on one distance vector.""" - x = tl.maximum(0.0, tl.minimum(r / rcut, 1.0)) - poly = a + x * (b + x * (c + x * d)) - env = 1.0 + _pow_int(x, exponent) * poly - return tl.where(x < 1.0, env, 0.0) - - -@triton.jit -def _compute_cutoff_envelope_grad( - r, - rcut, - a, - b, - c, - d, - exponent: tl.constexpr, -): - """Evaluate ``d envelope / d r`` on one distance vector.""" - x = tl.maximum(0.0, tl.minimum(r / rcut, 1.0)) - poly = a + x * (b + x * (c + x * d)) - poly_grad = b + 2.0 * c * x + 3.0 * d * x * x - if exponent == 1: - leading = poly - else: - leading = float(exponent) * _pow_int(x, exponent - 1) * poly - grad_x = leading + _pow_int(x, exponent) * poly_grad - return tl.where(x < 1.0, grad_x / rcut, 0.0) - - -@triton.jit -def _apply_inner_clamp( - raw_len, - r_inner, - r_outer, -): - """Apply the septic Hermite inner clamp.""" - delta = r_outer - r_inner - t = tl.maximum(0.0, tl.minimum((raw_len - r_inner) / delta, 1.0)) - t2 = t * t - t4 = t2 * t2 - h = t4 * (20.0 + t * (-45.0 + t * (36.0 - 10.0 * t))) - clamped = r_inner + delta * h - return tl.where(raw_len >= r_outer, raw_len, clamped) - - -@triton.jit -def _apply_inner_clamp_grad( - raw_len, - r_inner, - r_outer, -): - """Evaluate ``d clamp / d raw_len`` for the septic Hermite inner clamp.""" - delta = r_outer - r_inner - t = tl.maximum(0.0, tl.minimum((raw_len - r_inner) / delta, 1.0)) - t2 = t * t - t3 = t2 * t - grad = t3 * (80.0 + t * (-225.0 + t * (216.0 - 70.0 * t))) - return tl.where(raw_len >= r_outer, 1.0, grad) - - -@triton.jit -def edge_geometry_rbf_forward_kernel( - coord_ptr, - center_index_ptr, - neighbor_index_ptr, - freq_ptr, - edge_vec_ptr, - edge_len_ptr, - edge_env_ptr, - edge_rbf_ptr, - num_edges, - n_radial, - coord_stride_n, - coord_stride_c, - edge_vec_stride_e, - edge_vec_stride_c, - edge_rbf_stride_e, - edge_rbf_stride_r, - eps, - rcut, - edge_env_a, - edge_env_b, - edge_env_c, - edge_env_d, - radial_env_a, - radial_env_b, - radial_env_c, - radial_env_d, - r_inner, - r_outer, - EDGE_ENV_EXPONENT: tl.constexpr, - RADIAL_ENV_EXPONENT: tl.constexpr, - HAS_INNER_CLAMP: tl.constexpr, - BLOCK_EDGE: tl.constexpr, - BLOCK_RADIAL: tl.constexpr, -): - """Compute the fused edge geometry/RBF chain for one edge/radial tile.""" - pid_edge = tl.program_id(0) - pid_radial = tl.program_id(1) - - edge_offsets = pid_edge * BLOCK_EDGE + tl.arange(0, BLOCK_EDGE) - radial_offsets = pid_radial * BLOCK_RADIAL + tl.arange(0, BLOCK_RADIAL) - edge_mask = edge_offsets < num_edges - radial_mask = radial_offsets < n_radial - first_radial_mask = edge_mask & (pid_radial == 0) - - center_index = tl.load(center_index_ptr + edge_offsets, mask=edge_mask, other=0) - neighbor_index = tl.load(neighbor_index_ptr + edge_offsets, mask=edge_mask, other=0) - - center_x = tl.load( - coord_ptr + center_index * coord_stride_n + 0 * coord_stride_c, - mask=edge_mask, - other=0.0, - ) - center_y = tl.load( - coord_ptr + center_index * coord_stride_n + 1 * coord_stride_c, - mask=edge_mask, - other=0.0, - ) - center_z = tl.load( - coord_ptr + center_index * coord_stride_n + 2 * coord_stride_c, - mask=edge_mask, - other=0.0, - ) - neighbor_x = tl.load( - coord_ptr + neighbor_index * coord_stride_n + 0 * coord_stride_c, - mask=edge_mask, - other=0.0, - ) - neighbor_y = tl.load( - coord_ptr + neighbor_index * coord_stride_n + 1 * coord_stride_c, - mask=edge_mask, - other=0.0, - ) - neighbor_z = tl.load( - coord_ptr + neighbor_index * coord_stride_n + 2 * coord_stride_c, - mask=edge_mask, - other=0.0, - ) - - diff_x = neighbor_x - center_x - diff_y = neighbor_y - center_y - diff_z = neighbor_z - center_z - raw_len = tl.sqrt(diff_x * diff_x + diff_y * diff_y + diff_z * diff_z + eps * eps) - - if HAS_INNER_CLAMP: - clamped_len = _apply_inner_clamp(raw_len, r_inner, r_outer) - scale = clamped_len / raw_len - edge_vec_x = diff_x * scale - edge_vec_y = diff_y * scale - edge_vec_z = diff_z * scale - edge_len = clamped_len - else: - edge_vec_x = diff_x - edge_vec_y = diff_y - edge_vec_z = diff_z - edge_len = raw_len - - edge_env = _compute_cutoff_envelope( - edge_len, - rcut, - edge_env_a, - edge_env_b, - edge_env_c, - edge_env_d, - exponent=EDGE_ENV_EXPONENT, - ) - radial_env = _compute_cutoff_envelope( - edge_len, - rcut, - radial_env_a, - radial_env_b, - radial_env_c, - radial_env_d, - exponent=RADIAL_ENV_EXPONENT, - ) - - tl.store( - edge_vec_ptr + edge_offsets * edge_vec_stride_e + 0 * edge_vec_stride_c, - edge_vec_x, - mask=first_radial_mask, - ) - tl.store( - edge_vec_ptr + edge_offsets * edge_vec_stride_e + 1 * edge_vec_stride_c, - edge_vec_y, - mask=first_radial_mask, - ) - tl.store( - edge_vec_ptr + edge_offsets * edge_vec_stride_e + 2 * edge_vec_stride_c, - edge_vec_z, - mask=first_radial_mask, - ) - tl.store(edge_len_ptr + edge_offsets, edge_len, mask=first_radial_mask) - tl.store(edge_env_ptr + edge_offsets, edge_env, mask=first_radial_mask) - - freqs = tl.load(freq_ptr + radial_offsets, mask=radial_mask, other=0.0) - phase = edge_len[:, None] * freqs[None, :] - raw = freqs[None, :] * _safe_sinc_no_pi(phase) - edge_rbf = raw * radial_env[:, None] - tl.store( - edge_rbf_ptr - + edge_offsets[:, None] * edge_rbf_stride_e - + radial_offsets[None, :] * edge_rbf_stride_r, - edge_rbf, - mask=edge_mask[:, None] & radial_mask[None, :], - ) - - -@triton.jit -def edge_geometry_rbf_bwd_accum_kernel( - grad_edge_len_ptr, - grad_edge_env_ptr, - grad_edge_rbf_ptr, - coord_ptr, - center_index_ptr, - neighbor_index_ptr, - freq_ptr, - grad_r_total_ptr, - grad_freq_ptr, - num_edges, - n_radial, - coord_stride_n, - coord_stride_c, - grad_edge_rbf_stride_e, - grad_edge_rbf_stride_r, - eps, - rcut, - edge_env_a, - edge_env_b, - edge_env_c, - edge_env_d, - radial_env_a, - radial_env_b, - radial_env_c, - radial_env_d, - r_inner, - r_outer, - EDGE_ENV_EXPONENT: tl.constexpr, - RADIAL_ENV_EXPONENT: tl.constexpr, - HAS_INNER_CLAMP: tl.constexpr, - BLOCK_EDGE: tl.constexpr, - BLOCK_RADIAL: tl.constexpr, -): - """Accumulate scalar distance gradients and frequency gradients.""" - pid_edge = tl.program_id(0) - pid_radial = tl.program_id(1) - - edge_offsets = pid_edge * BLOCK_EDGE + tl.arange(0, BLOCK_EDGE) - radial_offsets = pid_radial * BLOCK_RADIAL + tl.arange(0, BLOCK_RADIAL) - edge_mask = edge_offsets < num_edges - radial_mask = radial_offsets < n_radial - - center_index = tl.load(center_index_ptr + edge_offsets, mask=edge_mask, other=0) - neighbor_index = tl.load(neighbor_index_ptr + edge_offsets, mask=edge_mask, other=0) - - center_x = tl.load( - coord_ptr + center_index * coord_stride_n + 0 * coord_stride_c, - mask=edge_mask, - other=0.0, - ) - center_y = tl.load( - coord_ptr + center_index * coord_stride_n + 1 * coord_stride_c, - mask=edge_mask, - other=0.0, - ) - center_z = tl.load( - coord_ptr + center_index * coord_stride_n + 2 * coord_stride_c, - mask=edge_mask, - other=0.0, - ) - neighbor_x = tl.load( - coord_ptr + neighbor_index * coord_stride_n + 0 * coord_stride_c, - mask=edge_mask, - other=0.0, - ) - neighbor_y = tl.load( - coord_ptr + neighbor_index * coord_stride_n + 1 * coord_stride_c, - mask=edge_mask, - other=0.0, - ) - neighbor_z = tl.load( - coord_ptr + neighbor_index * coord_stride_n + 2 * coord_stride_c, - mask=edge_mask, - other=0.0, - ) - - diff_x = neighbor_x - center_x - diff_y = neighbor_y - center_y - diff_z = neighbor_z - center_z - raw_len = tl.sqrt(diff_x * diff_x + diff_y * diff_y + diff_z * diff_z + eps * eps) - - if HAS_INNER_CLAMP: - edge_len = _apply_inner_clamp(raw_len, r_inner, r_outer) - else: - edge_len = raw_len - - radial_env = _compute_cutoff_envelope( - edge_len, - rcut, - radial_env_a, - radial_env_b, - radial_env_c, - radial_env_d, - exponent=RADIAL_ENV_EXPONENT, - ) - radial_env_grad = _compute_cutoff_envelope_grad( - edge_len, - rcut, - radial_env_a, - radial_env_b, - radial_env_c, - radial_env_d, - exponent=RADIAL_ENV_EXPONENT, - ) - - grad_edge_rbf = tl.load( - grad_edge_rbf_ptr - + edge_offsets[:, None] * grad_edge_rbf_stride_e - + radial_offsets[None, :] * grad_edge_rbf_stride_r, - mask=edge_mask[:, None] & radial_mask[None, :], - other=0.0, - ) - freqs = tl.load(freq_ptr + radial_offsets, mask=radial_mask, other=0.0) - phase = edge_len[:, None] * freqs[None, :] - raw = freqs[None, :] * _safe_sinc_no_pi(phase) - raw_grad_r = freqs[None, :] * freqs[None, :] * _safe_sinc_grad_no_pi(phase) - radial_grad_r = raw_grad_r * radial_env[:, None] + raw * radial_env_grad[:, None] - grad_rbf_to_r = tl.sum(grad_edge_rbf * radial_grad_r, axis=1) - tl.atomic_add(grad_r_total_ptr + edge_offsets, grad_rbf_to_r, mask=edge_mask) - - grad_freq = tl.sum(grad_edge_rbf * (radial_env[:, None] * tl.cos(phase)), axis=0) - tl.atomic_add(grad_freq_ptr + radial_offsets, grad_freq, mask=radial_mask) - - if pid_radial == 0: - grad_edge_len = tl.load( - grad_edge_len_ptr + edge_offsets, mask=edge_mask, other=0.0 - ) - grad_edge_env = tl.load( - grad_edge_env_ptr + edge_offsets, mask=edge_mask, other=0.0 - ) - edge_env_grad = _compute_cutoff_envelope_grad( - edge_len, - rcut, - edge_env_a, - edge_env_b, - edge_env_c, - edge_env_d, - exponent=EDGE_ENV_EXPONENT, - ) - base = grad_edge_len + grad_edge_env * edge_env_grad - tl.atomic_add(grad_r_total_ptr + edge_offsets, base, mask=edge_mask) - - -@triton.jit -def edge_geometry_rbf_bwd_coord_kernel( - grad_edge_vec_ptr, - grad_r_total_ptr, - coord_ptr, - center_index_ptr, - neighbor_index_ptr, - grad_coord_ptr, - num_edges, - coord_stride_n, - coord_stride_c, - grad_edge_vec_stride_e, - grad_edge_vec_stride_c, - grad_coord_stride_n, - grad_coord_stride_c, - eps, - r_inner, - r_outer, - HAS_INNER_CLAMP: tl.constexpr, - BLOCK_EDGE: tl.constexpr, -): - """Backpropagate the fused geometry/RBF chain into flat coordinates.""" - pid_edge = tl.program_id(0) - edge_offsets = pid_edge * BLOCK_EDGE + tl.arange(0, BLOCK_EDGE) - edge_mask = edge_offsets < num_edges - - center_index = tl.load(center_index_ptr + edge_offsets, mask=edge_mask, other=0) - neighbor_index = tl.load(neighbor_index_ptr + edge_offsets, mask=edge_mask, other=0) - - center_x = tl.load( - coord_ptr + center_index * coord_stride_n + 0 * coord_stride_c, - mask=edge_mask, - other=0.0, - ) - center_y = tl.load( - coord_ptr + center_index * coord_stride_n + 1 * coord_stride_c, - mask=edge_mask, - other=0.0, - ) - center_z = tl.load( - coord_ptr + center_index * coord_stride_n + 2 * coord_stride_c, - mask=edge_mask, - other=0.0, - ) - neighbor_x = tl.load( - coord_ptr + neighbor_index * coord_stride_n + 0 * coord_stride_c, - mask=edge_mask, - other=0.0, - ) - neighbor_y = tl.load( - coord_ptr + neighbor_index * coord_stride_n + 1 * coord_stride_c, - mask=edge_mask, - other=0.0, - ) - neighbor_z = tl.load( - coord_ptr + neighbor_index * coord_stride_n + 2 * coord_stride_c, - mask=edge_mask, - other=0.0, - ) - - diff_x = neighbor_x - center_x - diff_y = neighbor_y - center_y - diff_z = neighbor_z - center_z - raw_len = tl.sqrt(diff_x * diff_x + diff_y * diff_y + diff_z * diff_z + eps * eps) - - if HAS_INNER_CLAMP: - edge_len = _apply_inner_clamp(raw_len, r_inner, r_outer) - clamp_grad = _apply_inner_clamp_grad(raw_len, r_inner, r_outer) - scale = edge_len / raw_len - else: - edge_len = raw_len - clamp_grad = raw_len * 0.0 + 1.0 - scale = raw_len * 0.0 + 1.0 - - grad_edge_vec_x = tl.load( - grad_edge_vec_ptr - + edge_offsets * grad_edge_vec_stride_e - + 0 * grad_edge_vec_stride_c, - mask=edge_mask, - other=0.0, - ) - grad_edge_vec_y = tl.load( - grad_edge_vec_ptr - + edge_offsets * grad_edge_vec_stride_e - + 1 * grad_edge_vec_stride_c, - mask=edge_mask, - other=0.0, - ) - grad_edge_vec_z = tl.load( - grad_edge_vec_ptr - + edge_offsets * grad_edge_vec_stride_e - + 2 * grad_edge_vec_stride_c, - mask=edge_mask, - other=0.0, - ) - grad_r_total = tl.load(grad_r_total_ptr + edge_offsets, mask=edge_mask, other=0.0) - - dot_grad_vec = ( - grad_edge_vec_x * diff_x + grad_edge_vec_y * diff_y + grad_edge_vec_z * diff_z - ) - inv_raw_len = 1.0 / raw_len - scalar = grad_r_total * clamp_grad + dot_grad_vec * ( - (clamp_grad * raw_len - edge_len) * inv_raw_len * inv_raw_len - ) - grad_diff_common = scalar * inv_raw_len - grad_diff_x = grad_edge_vec_x * scale + diff_x * grad_diff_common - grad_diff_y = grad_edge_vec_y * scale + diff_y * grad_diff_common - grad_diff_z = grad_edge_vec_z * scale + diff_z * grad_diff_common - - tl.atomic_add( - grad_coord_ptr + neighbor_index * grad_coord_stride_n + 0 * grad_coord_stride_c, - grad_diff_x, - mask=edge_mask, - ) - tl.atomic_add( - grad_coord_ptr + neighbor_index * grad_coord_stride_n + 1 * grad_coord_stride_c, - grad_diff_y, - mask=edge_mask, - ) - tl.atomic_add( - grad_coord_ptr + neighbor_index * grad_coord_stride_n + 2 * grad_coord_stride_c, - grad_diff_z, - mask=edge_mask, - ) - tl.atomic_add( - grad_coord_ptr + center_index * grad_coord_stride_n + 0 * grad_coord_stride_c, - -grad_diff_x, - mask=edge_mask, - ) - tl.atomic_add( - grad_coord_ptr + center_index * grad_coord_stride_n + 1 * grad_coord_stride_c, - -grad_diff_y, - mask=edge_mask, - ) - tl.atomic_add( - grad_coord_ptr + center_index * grad_coord_stride_n + 2 * grad_coord_stride_c, - -grad_diff_z, - mask=edge_mask, - ) diff --git a/deepmd/pt/model/descriptor/sezm_nn/triton/kernels_generic.py b/deepmd/pt/model/descriptor/sezm_nn/triton/kernels_generic.py deleted file mode 100644 index f2a89abda9..0000000000 --- a/deepmd/pt/model/descriptor/sezm_nn/triton/kernels_generic.py +++ /dev/null @@ -1,555 +0,0 @@ -# SPDX-License-Identifier: LGPL-3.0-or-later -# pyright: reportMissingImports=false -# ruff: noqa: ANN001, ANN201 -"""Generic tiled Triton kernels for SeZM SO(2) rotation hot paths. - -This file holds the variable-``lmax`` family used once the packed SO(3) block -no longer fits the small specialized kernels. The tile sizes are fixed on -purpose: ``BLOCK_FULL == BLOCK_REDUCED == 16`` keeps every ``tl.dot`` on a CUDA -shape that Triton accepts, and the kernels below explicitly request -``input_precision="ieee"`` so float32 matches eager PyTorch instead of TF32. -""" - -from __future__ import ( - annotations, -) - -import triton -import triton.language as tl - -# Keep both contraction dimensions at 16 so Triton always sees a legal dot tile. - - -@triton.jit -def rotate_to_local_forward_kernel( - x_ptr, - src_ptr, - wigner_ptr, - coeff_index_ptr, - out_ptr, - num_edges, - reduced_dim, - dim_full, - channels, - x_stride_n, - x_stride_d, - x_stride_c, - wigner_stride_e, - wigner_stride_r, - wigner_stride_k, - out_stride_e, - out_stride_r, - out_stride_c, - BLOCK_REDUCED: tl.constexpr, - BLOCK_FULL: tl.constexpr, - BLOCK_CHANNEL: tl.constexpr, - GRID_E_STRIDE: tl.constexpr, -): - """Compute fused row-projected Wigner rotation ``D_to_m @ x[src]``.""" - edge_id = tl.program_id(0) - reduced_block_id = tl.program_id(1) - channel_block_id = tl.program_id(2) - - reduced_offsets = reduced_block_id * BLOCK_REDUCED + tl.arange(0, BLOCK_REDUCED) - channel_offsets = channel_block_id * BLOCK_CHANNEL + tl.arange(0, BLOCK_CHANNEL) - reduced_mask = reduced_offsets < reduced_dim - channel_mask = channel_offsets < channels - - while edge_id < num_edges: - src_idx = tl.load(src_ptr + edge_id).to(tl.int64) - coeff_rows = tl.load( - coeff_index_ptr + reduced_offsets, - mask=reduced_mask, - other=0, - ).to(tl.int64) - acc = tl.zeros((BLOCK_REDUCED, BLOCK_CHANNEL), dtype=tl.float32) - - for full_block in range(0, tl.cdiv(dim_full, BLOCK_FULL)): - full_offsets = full_block * BLOCK_FULL + tl.arange(0, BLOCK_FULL) - full_mask = full_offsets < dim_full - wigner_ptrs = ( - wigner_ptr - + edge_id * wigner_stride_e - + coeff_rows[:, None] * wigner_stride_r - + full_offsets[None, :] * wigner_stride_k - ) - x_ptrs = ( - x_ptr - + src_idx * x_stride_n - + full_offsets[:, None] * x_stride_d - + channel_offsets[None, :] * x_stride_c - ) - w_block = tl.load( - wigner_ptrs, - mask=reduced_mask[:, None] & full_mask[None, :], - other=0.0, - ) - x_block = tl.load( - x_ptrs, - mask=full_mask[:, None] & channel_mask[None, :], - other=0.0, - ) - # Match the eager autocast path: rotate in the activation dtype chosen - # by the current AMP context instead of forcing a higher Wigner dtype. - w_block = w_block.to(x_block.dtype) - acc = tl.dot( - w_block, - x_block, - acc, - input_precision="ieee", - ) - - out_ptrs = ( - out_ptr - + edge_id * out_stride_e - + reduced_offsets[:, None] * out_stride_r - + channel_offsets[None, :] * out_stride_c - ) - tl.store( - out_ptrs, - acc, - mask=reduced_mask[:, None] & channel_mask[None, :], - ) - edge_id += GRID_E_STRIDE - - -@triton.jit -def rotate_to_local_bwd_dx_kernel( - grad_out_ptr, - wigner_ptr, - coeff_index_ptr, - grad_edge_ptr, - num_edges, - reduced_dim, - dim_full, - channels, - grad_out_stride_e, - grad_out_stride_r, - grad_out_stride_c, - wigner_stride_e, - wigner_stride_r, - wigner_stride_k, - grad_edge_stride_e, - grad_edge_stride_d, - grad_edge_stride_c, - BLOCK_REDUCED: tl.constexpr, - BLOCK_FULL: tl.constexpr, - BLOCK_CHANNEL: tl.constexpr, - GRID_E_STRIDE: tl.constexpr, -): - """Compute per-edge source gradients ``D_to_m^T @ grad`` before scatter.""" - edge_id = tl.program_id(0) - full_block_id = tl.program_id(1) - channel_block_id = tl.program_id(2) - - full_offsets = full_block_id * BLOCK_FULL + tl.arange(0, BLOCK_FULL) - channel_offsets = channel_block_id * BLOCK_CHANNEL + tl.arange(0, BLOCK_CHANNEL) - full_mask = full_offsets < dim_full - channel_mask = channel_offsets < channels - - while edge_id < num_edges: - acc = tl.zeros((BLOCK_FULL, BLOCK_CHANNEL), dtype=tl.float32) - - for reduced_block in range(0, tl.cdiv(reduced_dim, BLOCK_REDUCED)): - reduced_offsets = reduced_block * BLOCK_REDUCED + tl.arange( - 0, BLOCK_REDUCED - ) - reduced_mask = reduced_offsets < reduced_dim - coeff_rows = tl.load( - coeff_index_ptr + reduced_offsets, - mask=reduced_mask, - other=0, - ).to(tl.int64) - wigner_ptrs = ( - wigner_ptr - + edge_id * wigner_stride_e - + coeff_rows[:, None] * wigner_stride_r - + full_offsets[None, :] * wigner_stride_k - ) - grad_ptrs = ( - grad_out_ptr - + edge_id * grad_out_stride_e - + reduced_offsets[:, None] * grad_out_stride_r - + channel_offsets[None, :] * grad_out_stride_c - ) - w_block = tl.load( - wigner_ptrs, - mask=reduced_mask[:, None] & full_mask[None, :], - other=0.0, - ) - grad_block = tl.load( - grad_ptrs, - mask=reduced_mask[:, None] & channel_mask[None, :], - other=0.0, - ) - w_block = w_block.to(grad_block.dtype) - acc = tl.dot( - tl.trans(w_block), - grad_block, - acc, - input_precision="ieee", - ) - - grad_edge_ptrs = ( - grad_edge_ptr - + edge_id * grad_edge_stride_e - + full_offsets[:, None] * grad_edge_stride_d - + channel_offsets[None, :] * grad_edge_stride_c - ) - tl.store( - grad_edge_ptrs, - acc, - mask=full_mask[:, None] & channel_mask[None, :], - ) - edge_id += GRID_E_STRIDE - - -@triton.jit -def rotate_to_local_bwd_dw_kernel( - grad_out_ptr, - x_ptr, - src_ptr, - coeff_index_ptr, - grad_rows_ptr, - num_edges, - reduced_dim, - dim_full, - channels, - grad_out_stride_e, - grad_out_stride_r, - grad_out_stride_c, - x_stride_n, - x_stride_d, - x_stride_c, - grad_rows_stride_e, - grad_rows_stride_r, - grad_rows_stride_d, - BLOCK_REDUCED: tl.constexpr, - BLOCK_FULL: tl.constexpr, - BLOCK_CHANNEL: tl.constexpr, - GRID_E_STRIDE: tl.constexpr, -): - """Compute row-selected Wigner gradients ``grad @ x[src]^T``.""" - edge_id = tl.program_id(0) - reduced_block_id = tl.program_id(1) - full_block_id = tl.program_id(2) - - reduced_offsets = reduced_block_id * BLOCK_REDUCED + tl.arange(0, BLOCK_REDUCED) - full_offsets = full_block_id * BLOCK_FULL + tl.arange(0, BLOCK_FULL) - reduced_mask = reduced_offsets < reduced_dim - full_mask = full_offsets < dim_full - - while edge_id < num_edges: - src_idx = tl.load(src_ptr + edge_id).to(tl.int64) - acc = tl.zeros((BLOCK_REDUCED, BLOCK_FULL), dtype=tl.float32) - - for channel_block in range(0, tl.cdiv(channels, BLOCK_CHANNEL)): - channel_offsets = channel_block * BLOCK_CHANNEL + tl.arange( - 0, BLOCK_CHANNEL - ) - channel_mask = channel_offsets < channels - grad_ptrs = ( - grad_out_ptr - + edge_id * grad_out_stride_e - + reduced_offsets[:, None] * grad_out_stride_r - + channel_offsets[None, :] * grad_out_stride_c - ) - x_ptrs = ( - x_ptr - + src_idx * x_stride_n - + full_offsets[:, None] * x_stride_d - + channel_offsets[None, :] * x_stride_c - ) - grad_block = tl.load( - grad_ptrs, - mask=reduced_mask[:, None] & channel_mask[None, :], - other=0.0, - ) - x_block = tl.load( - x_ptrs, - mask=full_mask[:, None] & channel_mask[None, :], - other=0.0, - ) - acc = tl.dot( - grad_block, - tl.trans(x_block), - acc, - input_precision="ieee", - ) - - grad_rows_ptrs = ( - grad_rows_ptr - + edge_id * grad_rows_stride_e - + reduced_offsets[:, None] * grad_rows_stride_r - + full_offsets[None, :] * grad_rows_stride_d - ) - tl.store( - grad_rows_ptrs, - acc, - mask=reduced_mask[:, None] & full_mask[None, :], - ) - edge_id += GRID_E_STRIDE - - -@triton.jit -def rotate_back_forward_kernel( - x_local_ptr, - wigner_ptr, - coeff_index_ptr, - out_ptr, - num_edges, - reduced_dim, - dim_full, - channels, - x_local_stride_e, - x_local_stride_r, - x_local_stride_c, - wigner_stride_e, - wigner_stride_r, - wigner_stride_k, - out_stride_e, - out_stride_d, - out_stride_c, - BLOCK_REDUCED: tl.constexpr, - BLOCK_FULL: tl.constexpr, - BLOCK_CHANNEL: tl.constexpr, - GRID_E_STRIDE: tl.constexpr, -): - """Compute fused inverse rotation ``Dt_from_m @ x_local``.""" - edge_id = tl.program_id(0) - full_block_id = tl.program_id(1) - channel_block_id = tl.program_id(2) - - full_offsets = full_block_id * BLOCK_FULL + tl.arange(0, BLOCK_FULL) - channel_offsets = channel_block_id * BLOCK_CHANNEL + tl.arange(0, BLOCK_CHANNEL) - full_mask = full_offsets < dim_full - channel_mask = channel_offsets < channels - - while edge_id < num_edges: - acc = tl.zeros((BLOCK_FULL, BLOCK_CHANNEL), dtype=tl.float32) - - for reduced_block in range(0, tl.cdiv(reduced_dim, BLOCK_REDUCED)): - reduced_offsets = reduced_block * BLOCK_REDUCED + tl.arange( - 0, BLOCK_REDUCED - ) - reduced_mask = reduced_offsets < reduced_dim - coeff_cols = tl.load( - coeff_index_ptr + reduced_offsets, - mask=reduced_mask, - other=0, - ).to(tl.int64) - wigner_ptrs = ( - wigner_ptr - + edge_id * wigner_stride_e - + full_offsets[:, None] * wigner_stride_r - + coeff_cols[None, :] * wigner_stride_k - ) - x_ptrs = ( - x_local_ptr - + edge_id * x_local_stride_e - + reduced_offsets[:, None] * x_local_stride_r - + channel_offsets[None, :] * x_local_stride_c - ) - w_block = tl.load( - wigner_ptrs, - mask=full_mask[:, None] & reduced_mask[None, :], - other=0.0, - ) - x_block = tl.load( - x_ptrs, - mask=reduced_mask[:, None] & channel_mask[None, :], - other=0.0, - ) - w_block = w_block.to(x_block.dtype) - acc = tl.dot( - w_block, - x_block, - acc, - input_precision="ieee", - ) - - out_ptrs = ( - out_ptr - + edge_id * out_stride_e - + full_offsets[:, None] * out_stride_d - + channel_offsets[None, :] * out_stride_c - ) - tl.store( - out_ptrs, - acc, - mask=full_mask[:, None] & channel_mask[None, :], - ) - edge_id += GRID_E_STRIDE - - -@triton.jit -def rotate_back_bwd_dx_kernel( - grad_out_ptr, - wigner_ptr, - coeff_index_ptr, - grad_x_ptr, - num_edges, - reduced_dim, - dim_full, - channels, - grad_out_stride_e, - grad_out_stride_d, - grad_out_stride_c, - wigner_stride_e, - wigner_stride_r, - wigner_stride_k, - grad_x_stride_e, - grad_x_stride_r, - grad_x_stride_c, - BLOCK_REDUCED: tl.constexpr, - BLOCK_FULL: tl.constexpr, - BLOCK_CHANNEL: tl.constexpr, - GRID_E_STRIDE: tl.constexpr, -): - """Compute reduced-layout gradients ``Dt_from_m^T @ grad``.""" - edge_id = tl.program_id(0) - reduced_block_id = tl.program_id(1) - channel_block_id = tl.program_id(2) - - reduced_offsets = reduced_block_id * BLOCK_REDUCED + tl.arange(0, BLOCK_REDUCED) - channel_offsets = channel_block_id * BLOCK_CHANNEL + tl.arange(0, BLOCK_CHANNEL) - reduced_mask = reduced_offsets < reduced_dim - channel_mask = channel_offsets < channels - - while edge_id < num_edges: - coeff_cols = tl.load( - coeff_index_ptr + reduced_offsets, - mask=reduced_mask, - other=0, - ).to(tl.int64) - acc = tl.zeros((BLOCK_REDUCED, BLOCK_CHANNEL), dtype=tl.float32) - - for full_block in range(0, tl.cdiv(dim_full, BLOCK_FULL)): - full_offsets = full_block * BLOCK_FULL + tl.arange(0, BLOCK_FULL) - full_mask = full_offsets < dim_full - wigner_ptrs = ( - wigner_ptr - + edge_id * wigner_stride_e - + full_offsets[:, None] * wigner_stride_r - + coeff_cols[None, :] * wigner_stride_k - ) - grad_ptrs = ( - grad_out_ptr - + edge_id * grad_out_stride_e - + full_offsets[:, None] * grad_out_stride_d - + channel_offsets[None, :] * grad_out_stride_c - ) - w_block = tl.load( - wigner_ptrs, - mask=full_mask[:, None] & reduced_mask[None, :], - other=0.0, - ) - grad_block = tl.load( - grad_ptrs, - mask=full_mask[:, None] & channel_mask[None, :], - other=0.0, - ) - w_block = w_block.to(grad_block.dtype) - acc = tl.dot( - tl.trans(w_block), - grad_block, - acc, - input_precision="ieee", - ) - - grad_x_ptrs = ( - grad_x_ptr - + edge_id * grad_x_stride_e - + reduced_offsets[:, None] * grad_x_stride_r - + channel_offsets[None, :] * grad_x_stride_c - ) - tl.store( - grad_x_ptrs, - acc, - mask=reduced_mask[:, None] & channel_mask[None, :], - ) - edge_id += GRID_E_STRIDE - - -@triton.jit -def rotate_back_bwd_dw_kernel( - grad_out_ptr, - x_local_ptr, - grad_cols_ptr, - num_edges, - reduced_dim, - dim_full, - channels, - grad_out_stride_e, - grad_out_stride_d, - grad_out_stride_c, - x_local_stride_e, - x_local_stride_r, - x_local_stride_c, - grad_cols_stride_e, - grad_cols_stride_d, - grad_cols_stride_r, - BLOCK_REDUCED: tl.constexpr, - BLOCK_FULL: tl.constexpr, - BLOCK_CHANNEL: tl.constexpr, - GRID_E_STRIDE: tl.constexpr, -): - """Compute column-selected inverse Wigner gradients ``grad @ x_local^T``.""" - edge_id = tl.program_id(0) - full_block_id = tl.program_id(1) - reduced_block_id = tl.program_id(2) - - full_offsets = full_block_id * BLOCK_FULL + tl.arange(0, BLOCK_FULL) - reduced_offsets = reduced_block_id * BLOCK_REDUCED + tl.arange(0, BLOCK_REDUCED) - full_mask = full_offsets < dim_full - reduced_mask = reduced_offsets < reduced_dim - - while edge_id < num_edges: - acc = tl.zeros((BLOCK_FULL, BLOCK_REDUCED), dtype=tl.float32) - - for channel_block in range(0, tl.cdiv(channels, BLOCK_CHANNEL)): - channel_offsets = channel_block * BLOCK_CHANNEL + tl.arange( - 0, BLOCK_CHANNEL - ) - channel_mask = channel_offsets < channels - grad_ptrs = ( - grad_out_ptr - + edge_id * grad_out_stride_e - + full_offsets[:, None] * grad_out_stride_d - + channel_offsets[None, :] * grad_out_stride_c - ) - x_ptrs = ( - x_local_ptr - + edge_id * x_local_stride_e - + reduced_offsets[:, None] * x_local_stride_r - + channel_offsets[None, :] * x_local_stride_c - ) - grad_block = tl.load( - grad_ptrs, - mask=full_mask[:, None] & channel_mask[None, :], - other=0.0, - ) - x_block = tl.load( - x_ptrs, - mask=reduced_mask[:, None] & channel_mask[None, :], - other=0.0, - ) - acc = tl.dot( - grad_block, - tl.trans(x_block), - acc, - input_precision="ieee", - ) - - grad_cols_ptrs = ( - grad_cols_ptr - + edge_id * grad_cols_stride_e - + full_offsets[:, None] * grad_cols_stride_d - + reduced_offsets[None, :] * grad_cols_stride_r - ) - tl.store( - grad_cols_ptrs, - acc, - mask=full_mask[:, None] & reduced_mask[None, :], - ) - edge_id += GRID_E_STRIDE diff --git a/deepmd/pt/model/descriptor/sezm_nn/triton/kernels_small.py b/deepmd/pt/model/descriptor/sezm_nn/triton/kernels_small.py deleted file mode 100644 index 524acfd72f..0000000000 --- a/deepmd/pt/model/descriptor/sezm_nn/triton/kernels_small.py +++ /dev/null @@ -1,1317 +0,0 @@ -# SPDX-License-Identifier: LGPL-3.0-or-later -# pyright: reportMissingImports=false -# ruff: noqa: ANN001, ANN201 -"""Specialized small-family Triton kernels for SeZM SO(2) rotations. - -These kernels are the intended fast path for ``lmax <= 3``. They keep one -masked ``16x16`` Wigner tile in registers, so ``lmax=0`` and ``lmax=1`` can -share the same specialized family without paying the loop overhead of the -generic tiled kernels. -""" - -from __future__ import ( - annotations, -) - -import triton -import triton.language as tl - -from .constants import TRITON_SMALL_FULL_DIM as TRITON_SMALL_FULL_DIM_VALUE - -# Small kernels always materialize one padded ``16x16`` block and mask tails. -TRITON_SMALL_FULL_DIM = tl.constexpr(TRITON_SMALL_FULL_DIM_VALUE) - - -@triton.jit -def _load_full_wigner_matrix( - wigner_ptr, - edge_id, - full_dim, - wigner_stride_e, - wigner_stride_r, - wigner_stride_k, -) -> tl.tensor: - """Load one padded ``16x16`` Wigner block in l-major order.""" - full_offsets = tl.arange(0, TRITON_SMALL_FULL_DIM) - full_mask = full_offsets < full_dim - wigner_ptrs = ( - wigner_ptr - + edge_id * wigner_stride_e - + full_offsets[:, None] * wigner_stride_r - + full_offsets[None, :] * wigner_stride_k - ) - return tl.load( - wigner_ptrs, - mask=full_mask[:, None] & full_mask[None, :], - other=0.0, - ) - - -@triton.jit -def _load_full_node_values( - x_ptr, - node_idx, - full_dim, - channel_offsets, - channel_mask, - x_stride_n, - x_stride_d, - x_stride_c, -) -> tl.tensor: - """Load one padded ``16xC`` node feature block in l-major order.""" - full_offsets = tl.arange(0, TRITON_SMALL_FULL_DIM) - full_mask = full_offsets < full_dim - x_ptrs = ( - x_ptr - + node_idx * x_stride_n - + full_offsets[:, None] * x_stride_d - + channel_offsets[None, :] * x_stride_c - ) - return tl.load( - x_ptrs, - mask=full_mask[:, None] & channel_mask[None, :], - other=0.0, - ) - - -@triton.jit -def _load_reduced_values_with_index( - x_ptr, - coeff_index_ptr, - edge_id, - reduced_dim, - channel_offsets, - channel_mask, - x_stride_e, - x_stride_r, - x_stride_c, -) -> tuple[tl.tensor, tl.tensor, tl.tensor]: - """Load reduced values together with the padded reduced->full row mapping.""" - reduced_offsets = tl.arange(0, TRITON_SMALL_FULL_DIM) - reduced_mask = reduced_offsets < reduced_dim - x_ptrs = ( - x_ptr - + edge_id * x_stride_e - + reduced_offsets[:, None] * x_stride_r - + channel_offsets[None, :] * x_stride_c - ) - reduced_values = tl.load( - x_ptrs, - mask=reduced_mask[:, None] & channel_mask[None, :], - other=0.0, - ) - coeff_rows = tl.load( - coeff_index_ptr + reduced_offsets, - mask=reduced_mask, - other=-1, - ).to(tl.int64) - return reduced_values, reduced_mask, coeff_rows - - -@triton.jit -def _scatter_reduced_to_full_matrix( - reduced_values, - reduced_mask, - coeff_rows, - BLOCK_CHANNEL: tl.constexpr, -) -> tl.tensor: - """Scatter a padded reduced block into a padded full l-major block.""" - row_ids = tl.arange(0, TRITON_SMALL_FULL_DIM) - full_values = tl.zeros( - (TRITON_SMALL_FULL_DIM, BLOCK_CHANNEL), - dtype=reduced_values.dtype, - ) - for row in range(TRITON_SMALL_FULL_DIM): - row_mask = (coeff_rows == row)[:, None] & reduced_mask[:, None] - row_value = tl.sum(tl.where(row_mask, reduced_values, 0.0), axis=0).to( - reduced_values.dtype - ) - full_values = tl.where( - row_ids[:, None] == row, - row_value[None, :], - full_values, - ) - return full_values - - -@triton.jit -def _select_reduced_from_full_matrix( - full_values, - reduced_mask, - coeff_rows, - BLOCK_CHANNEL: tl.constexpr, -) -> tl.tensor: - """Select reduced rows from a padded full l-major block.""" - row_ids = tl.arange(0, TRITON_SMALL_FULL_DIM) - reduced_values = tl.zeros( - (TRITON_SMALL_FULL_DIM, BLOCK_CHANNEL), - dtype=full_values.dtype, - ) - for row in range(TRITON_SMALL_FULL_DIM): - row_value = tl.sum( - tl.where(row_ids[:, None] == row, full_values, 0.0), - axis=0, - ).to(full_values.dtype) - reduced_values = tl.where( - (coeff_rows == row)[:, None] & reduced_mask[:, None], - row_value[None, :], - reduced_values, - ) - return reduced_values - - -@triton.jit -def _build_full_matrix_l1( - y0, - y1, - y2, - y3, - BLOCK_CHANNEL: tl.constexpr, -) -> tl.tensor: - """Build a padded full matrix from the ``lmax=1`` row vectors.""" - row_ids = tl.arange(0, TRITON_SMALL_FULL_DIM) - full_values = tl.zeros( - (TRITON_SMALL_FULL_DIM, BLOCK_CHANNEL), - dtype=tl.float32, - ) - full_values = tl.where(row_ids[:, None] == 0, y0[None, :], full_values) - full_values = tl.where(row_ids[:, None] == 1, y1[None, :], full_values) - full_values = tl.where(row_ids[:, None] == 2, y2[None, :], full_values) - full_values = tl.where(row_ids[:, None] == 3, y3[None, :], full_values) - return full_values - - -@triton.jit -def _build_full_matrix_l2( - y0, - y1, - y2, - y3, - y4, - y5, - y6, - y7, - y8, - BLOCK_CHANNEL: tl.constexpr, -) -> tl.tensor: - """Build a padded full matrix from the ``lmax=2`` row vectors.""" - row_ids = tl.arange(0, TRITON_SMALL_FULL_DIM) - full_values = tl.zeros( - (TRITON_SMALL_FULL_DIM, BLOCK_CHANNEL), - dtype=tl.float32, - ) - full_values = tl.where(row_ids[:, None] == 0, y0[None, :], full_values) - full_values = tl.where(row_ids[:, None] == 1, y1[None, :], full_values) - full_values = tl.where(row_ids[:, None] == 2, y2[None, :], full_values) - full_values = tl.where(row_ids[:, None] == 3, y3[None, :], full_values) - full_values = tl.where(row_ids[:, None] == 4, y4[None, :], full_values) - full_values = tl.where(row_ids[:, None] == 5, y5[None, :], full_values) - full_values = tl.where(row_ids[:, None] == 6, y6[None, :], full_values) - full_values = tl.where(row_ids[:, None] == 7, y7[None, :], full_values) - full_values = tl.where(row_ids[:, None] == 8, y8[None, :], full_values) - return full_values - - -@triton.jit -def _matvec_l1( - w_full, - x_full, -) -> tl.tensor: - """Apply the packed ``lmax=1`` block-diagonal Wigner matrix.""" - return tl.dot(w_full.to(x_full.dtype), x_full, input_precision="ieee") - - -@triton.jit -def _matvec_t_l1( - w_full, - x_full, -) -> tl.tensor: - """Apply the transpose of the packed ``lmax=1`` Wigner matrix.""" - return tl.dot( - tl.trans(w_full.to(x_full.dtype)), - x_full, - input_precision="ieee", - ) - - -@triton.jit -def _matvec_l2( - w_full, - x_full, -) -> tl.tensor: - """Apply the packed ``lmax=2`` block-diagonal Wigner matrix.""" - return tl.dot(w_full.to(x_full.dtype), x_full, input_precision="ieee") - - -@triton.jit -def _matvec_t_l2( - w_full, - x_full, -) -> tl.tensor: - """Apply the transpose of the packed ``lmax=2`` Wigner matrix.""" - return tl.dot( - tl.trans(w_full.to(x_full.dtype)), - x_full, - input_precision="ieee", - ) - - -@triton.jit -def rotate_to_local_l1_forward_kernel( - x_ptr, - src_ptr, - wigner_ptr, - coeff_index_ptr, - out_ptr, - num_edges, - reduced_dim, - full_dim, - channels, - x_stride_n, - x_stride_d, - x_stride_c, - wigner_stride_e, - wigner_stride_r, - wigner_stride_k, - out_stride_e, - out_stride_r, - out_stride_c, - BLOCK_CHANNEL: tl.constexpr, - GRID_E_STRIDE: tl.constexpr, -): - """Fused ``global -> local reduced`` rotation specialized for ``lmax=1``.""" - edge_id = tl.program_id(0) - channel_block_id = tl.program_id(1) - channel_offsets = channel_block_id * BLOCK_CHANNEL + tl.arange(0, BLOCK_CHANNEL) - channel_mask = channel_offsets < channels - while edge_id < num_edges: - src_idx = tl.load(src_ptr + edge_id).to(tl.int64) - coeff_rows = tl.load( - coeff_index_ptr + tl.arange(0, TRITON_SMALL_FULL_DIM), - mask=tl.arange(0, TRITON_SMALL_FULL_DIM) < reduced_dim, - other=-1, - ).to(tl.int64) - reduced_mask = tl.arange(0, TRITON_SMALL_FULL_DIM) < reduced_dim - x_full = _load_full_node_values( - x_ptr, - src_idx, - full_dim, - channel_offsets, - channel_mask, - x_stride_n, - x_stride_d, - x_stride_c, - ) - w_full = _load_full_wigner_matrix( - wigner_ptr, - edge_id, - full_dim, - wigner_stride_e, - wigner_stride_r, - wigner_stride_k, - ).to(x_full.dtype) - y_full = _matvec_l1(w_full, x_full) - out_values = _select_reduced_from_full_matrix( - y_full, - reduced_mask, - coeff_rows, - BLOCK_CHANNEL=BLOCK_CHANNEL, - ) - out_ptrs = ( - out_ptr - + edge_id * out_stride_e - + tl.arange(0, TRITON_SMALL_FULL_DIM)[:, None] * out_stride_r - + channel_offsets[None, :] * out_stride_c - ) - tl.store( - out_ptrs, - out_values, - mask=reduced_mask[:, None] & channel_mask[None, :], - ) - edge_id += GRID_E_STRIDE - - -@triton.jit -def rotate_to_local_l2_forward_kernel( - x_ptr, - src_ptr, - wigner_ptr, - coeff_index_ptr, - out_ptr, - num_edges, - reduced_dim, - full_dim, - channels, - x_stride_n, - x_stride_d, - x_stride_c, - wigner_stride_e, - wigner_stride_r, - wigner_stride_k, - out_stride_e, - out_stride_r, - out_stride_c, - BLOCK_CHANNEL: tl.constexpr, - GRID_E_STRIDE: tl.constexpr, -): - """Fused ``global -> local reduced`` rotation specialized for ``lmax=2``.""" - edge_id = tl.program_id(0) - channel_block_id = tl.program_id(1) - channel_offsets = channel_block_id * BLOCK_CHANNEL + tl.arange(0, BLOCK_CHANNEL) - channel_mask = channel_offsets < channels - while edge_id < num_edges: - src_idx = tl.load(src_ptr + edge_id).to(tl.int64) - coeff_rows = tl.load( - coeff_index_ptr + tl.arange(0, TRITON_SMALL_FULL_DIM), - mask=tl.arange(0, TRITON_SMALL_FULL_DIM) < reduced_dim, - other=-1, - ).to(tl.int64) - reduced_mask = tl.arange(0, TRITON_SMALL_FULL_DIM) < reduced_dim - x_full = _load_full_node_values( - x_ptr, - src_idx, - full_dim, - channel_offsets, - channel_mask, - x_stride_n, - x_stride_d, - x_stride_c, - ) - w_full = _load_full_wigner_matrix( - wigner_ptr, - edge_id, - full_dim, - wigner_stride_e, - wigner_stride_r, - wigner_stride_k, - ).to(x_full.dtype) - y_full = _matvec_l2(w_full, x_full) - out_values = _select_reduced_from_full_matrix( - y_full, - reduced_mask, - coeff_rows, - BLOCK_CHANNEL=BLOCK_CHANNEL, - ) - out_ptrs = ( - out_ptr - + edge_id * out_stride_e - + tl.arange(0, TRITON_SMALL_FULL_DIM)[:, None] * out_stride_r - + channel_offsets[None, :] * out_stride_c - ) - tl.store( - out_ptrs, - out_values, - mask=reduced_mask[:, None] & channel_mask[None, :], - ) - edge_id += GRID_E_STRIDE - - -@triton.jit -def rotate_to_local_l3_forward_kernel( - x_ptr, - src_ptr, - wigner_ptr, - coeff_index_ptr, - out_ptr, - num_edges, - reduced_dim, - full_dim, - channels, - x_stride_n, - x_stride_d, - x_stride_c, - wigner_stride_e, - wigner_stride_r, - wigner_stride_k, - out_stride_e, - out_stride_r, - out_stride_c, - BLOCK_CHANNEL: tl.constexpr, - GRID_E_STRIDE: tl.constexpr, -): - """Fused ``global -> local reduced`` rotation specialized for ``lmax=3``.""" - edge_id = tl.program_id(0) - channel_block_id = tl.program_id(1) - channel_offsets = channel_block_id * BLOCK_CHANNEL + tl.arange(0, BLOCK_CHANNEL) - channel_mask = channel_offsets < channels - while edge_id < num_edges: - src_idx = tl.load(src_ptr + edge_id).to(tl.int64) - coeff_rows = tl.load( - coeff_index_ptr + tl.arange(0, TRITON_SMALL_FULL_DIM), - mask=tl.arange(0, TRITON_SMALL_FULL_DIM) < reduced_dim, - other=-1, - ).to(tl.int64) - reduced_mask = tl.arange(0, TRITON_SMALL_FULL_DIM) < reduced_dim - x_full = _load_full_node_values( - x_ptr, - src_idx, - full_dim, - channel_offsets, - channel_mask, - x_stride_n, - x_stride_d, - x_stride_c, - ) - w_full = _load_full_wigner_matrix( - wigner_ptr, - edge_id, - full_dim, - wigner_stride_e, - wigner_stride_r, - wigner_stride_k, - ).to(x_full.dtype) - y_full = tl.dot(w_full, x_full, input_precision="ieee") - out_values = _select_reduced_from_full_matrix( - y_full, - reduced_mask, - coeff_rows, - BLOCK_CHANNEL=BLOCK_CHANNEL, - ) - out_ptrs = ( - out_ptr - + edge_id * out_stride_e - + tl.arange(0, TRITON_SMALL_FULL_DIM)[:, None] * out_stride_r - + channel_offsets[None, :] * out_stride_c - ) - tl.store( - out_ptrs, - out_values, - mask=reduced_mask[:, None] & channel_mask[None, :], - ) - edge_id += GRID_E_STRIDE - - -@triton.jit -def rotate_to_local_l1_bwd_dx_kernel( - grad_out_ptr, - wigner_ptr, - coeff_index_ptr, - grad_edge_ptr, - num_edges, - reduced_dim, - full_dim, - channels, - grad_out_stride_e, - grad_out_stride_r, - grad_out_stride_c, - wigner_stride_e, - wigner_stride_r, - wigner_stride_k, - grad_edge_stride_e, - grad_edge_stride_d, - grad_edge_stride_c, - BLOCK_CHANNEL: tl.constexpr, - GRID_E_STRIDE: tl.constexpr, -): - """Compute per-edge source gradients specialized for ``lmax=1``.""" - edge_id = tl.program_id(0) - channel_block_id = tl.program_id(1) - channel_offsets = channel_block_id * BLOCK_CHANNEL + tl.arange(0, BLOCK_CHANNEL) - channel_mask = channel_offsets < channels - while edge_id < num_edges: - grad_reduced, reduced_mask, coeff_rows = _load_reduced_values_with_index( - grad_out_ptr, - coeff_index_ptr, - edge_id, - reduced_dim, - channel_offsets, - channel_mask, - grad_out_stride_e, - grad_out_stride_r, - grad_out_stride_c, - ) - grad_full = _scatter_reduced_to_full_matrix( - grad_reduced, - reduced_mask, - coeff_rows, - BLOCK_CHANNEL=BLOCK_CHANNEL, - ) - w_full = _load_full_wigner_matrix( - wigner_ptr, - edge_id, - full_dim, - wigner_stride_e, - wigner_stride_r, - wigner_stride_k, - ).to(grad_full.dtype) - dx_full = _matvec_t_l1(w_full, grad_full) - full_mask = tl.arange(0, TRITON_SMALL_FULL_DIM) < full_dim - grad_edge_ptrs = ( - grad_edge_ptr - + edge_id * grad_edge_stride_e - + tl.arange(0, TRITON_SMALL_FULL_DIM)[:, None] * grad_edge_stride_d - + channel_offsets[None, :] * grad_edge_stride_c - ) - tl.store( - grad_edge_ptrs, - dx_full, - mask=full_mask[:, None] & channel_mask[None, :], - ) - edge_id += GRID_E_STRIDE - - -@triton.jit -def rotate_to_local_l2_bwd_dx_kernel( - grad_out_ptr, - wigner_ptr, - coeff_index_ptr, - grad_edge_ptr, - num_edges, - reduced_dim, - full_dim, - channels, - grad_out_stride_e, - grad_out_stride_r, - grad_out_stride_c, - wigner_stride_e, - wigner_stride_r, - wigner_stride_k, - grad_edge_stride_e, - grad_edge_stride_d, - grad_edge_stride_c, - BLOCK_CHANNEL: tl.constexpr, - GRID_E_STRIDE: tl.constexpr, -): - """Compute per-edge source gradients specialized for ``lmax=2``.""" - edge_id = tl.program_id(0) - channel_block_id = tl.program_id(1) - channel_offsets = channel_block_id * BLOCK_CHANNEL + tl.arange(0, BLOCK_CHANNEL) - channel_mask = channel_offsets < channels - while edge_id < num_edges: - grad_reduced, reduced_mask, coeff_rows = _load_reduced_values_with_index( - grad_out_ptr, - coeff_index_ptr, - edge_id, - reduced_dim, - channel_offsets, - channel_mask, - grad_out_stride_e, - grad_out_stride_r, - grad_out_stride_c, - ) - grad_full = _scatter_reduced_to_full_matrix( - grad_reduced, - reduced_mask, - coeff_rows, - BLOCK_CHANNEL=BLOCK_CHANNEL, - ) - w_full = _load_full_wigner_matrix( - wigner_ptr, - edge_id, - full_dim, - wigner_stride_e, - wigner_stride_r, - wigner_stride_k, - ).to(grad_full.dtype) - dx_full = _matvec_t_l2(w_full, grad_full) - full_mask = tl.arange(0, TRITON_SMALL_FULL_DIM) < full_dim - grad_edge_ptrs = ( - grad_edge_ptr - + edge_id * grad_edge_stride_e - + tl.arange(0, TRITON_SMALL_FULL_DIM)[:, None] * grad_edge_stride_d - + channel_offsets[None, :] * grad_edge_stride_c - ) - tl.store( - grad_edge_ptrs, - dx_full, - mask=full_mask[:, None] & channel_mask[None, :], - ) - edge_id += GRID_E_STRIDE - - -@triton.jit -def rotate_to_local_l3_bwd_dx_kernel( - grad_out_ptr, - wigner_ptr, - coeff_index_ptr, - grad_edge_ptr, - num_edges, - reduced_dim, - full_dim, - channels, - grad_out_stride_e, - grad_out_stride_r, - grad_out_stride_c, - wigner_stride_e, - wigner_stride_r, - wigner_stride_k, - grad_edge_stride_e, - grad_edge_stride_d, - grad_edge_stride_c, - BLOCK_CHANNEL: tl.constexpr, - GRID_E_STRIDE: tl.constexpr, -): - """Compute per-edge source gradients specialized for ``lmax=3``.""" - edge_id = tl.program_id(0) - channel_block_id = tl.program_id(1) - channel_offsets = channel_block_id * BLOCK_CHANNEL + tl.arange(0, BLOCK_CHANNEL) - channel_mask = channel_offsets < channels - while edge_id < num_edges: - grad_reduced, reduced_mask, coeff_rows = _load_reduced_values_with_index( - grad_out_ptr, - coeff_index_ptr, - edge_id, - reduced_dim, - channel_offsets, - channel_mask, - grad_out_stride_e, - grad_out_stride_r, - grad_out_stride_c, - ) - grad_full = _scatter_reduced_to_full_matrix( - grad_reduced, - reduced_mask, - coeff_rows, - BLOCK_CHANNEL=BLOCK_CHANNEL, - ) - w_full = _load_full_wigner_matrix( - wigner_ptr, - edge_id, - full_dim, - wigner_stride_e, - wigner_stride_r, - wigner_stride_k, - ).to(grad_full.dtype) - dx_full = tl.dot( - tl.trans(w_full), - grad_full, - input_precision="ieee", - ) - full_mask = tl.arange(0, TRITON_SMALL_FULL_DIM) < full_dim - grad_edge_ptrs = ( - grad_edge_ptr - + edge_id * grad_edge_stride_e - + tl.arange(0, TRITON_SMALL_FULL_DIM)[:, None] * grad_edge_stride_d - + channel_offsets[None, :] * grad_edge_stride_c - ) - tl.store( - grad_edge_ptrs, - dx_full, - mask=full_mask[:, None] & channel_mask[None, :], - ) - edge_id += GRID_E_STRIDE - - -@triton.jit -def rotate_to_local_small_bwd_dw_kernel( - grad_out_ptr, - x_ptr, - src_ptr, - coeff_index_ptr, - grad_wigner_ptr, - num_edges, - reduced_dim, - full_dim, - channels, - grad_out_stride_e, - grad_out_stride_r, - grad_out_stride_c, - x_stride_n, - x_stride_d, - x_stride_c, - grad_wigner_stride_e, - grad_wigner_stride_r, - grad_wigner_stride_k, - BLOCK_CHANNEL: tl.constexpr, - GRID_E_STRIDE: tl.constexpr, -): - """Compute full Wigner gradients for specialized small-l rotate-to-local.""" - edge_id = tl.program_id(0) - channel_offsets = tl.arange(0, BLOCK_CHANNEL) - full_offsets = tl.arange(0, TRITON_SMALL_FULL_DIM) - while edge_id < num_edges: - coeff_rows = tl.load( - coeff_index_ptr + full_offsets, - mask=full_offsets < reduced_dim, - other=-1, - ).to(tl.int64) - reduced_mask = full_offsets < reduced_dim - src_idx = tl.load(src_ptr + edge_id).to(tl.int64) - grad_w_acc = tl.zeros( - (TRITON_SMALL_FULL_DIM, TRITON_SMALL_FULL_DIM), - dtype=tl.float32, - ) - channel_start = 0 - while channel_start < channels: - block_offsets = channel_start + channel_offsets - channel_mask = block_offsets < channels - grad_reduced, _, _ = _load_reduced_values_with_index( - grad_out_ptr, - coeff_index_ptr, - edge_id, - reduced_dim, - block_offsets, - channel_mask, - grad_out_stride_e, - grad_out_stride_r, - grad_out_stride_c, - ) - grad_full_block = _scatter_reduced_to_full_matrix( - grad_reduced, - reduced_mask, - coeff_rows, - BLOCK_CHANNEL=BLOCK_CHANNEL, - ) - x_full_block = _load_full_node_values( - x_ptr, - src_idx, - full_dim, - block_offsets, - channel_mask, - x_stride_n, - x_stride_d, - x_stride_c, - ) - grad_w_acc += tl.dot( - grad_full_block, - tl.trans(x_full_block.to(grad_full_block.dtype)), - input_precision="ieee", - ) - channel_start += BLOCK_CHANNEL - grad_w_ptrs = ( - grad_wigner_ptr - + edge_id * grad_wigner_stride_e - + full_offsets[:, None] * grad_wigner_stride_r - + full_offsets[None, :] * grad_wigner_stride_k - ) - full_mask = full_offsets < full_dim - tl.store( - grad_w_ptrs, - grad_w_acc, - mask=full_mask[:, None] & full_mask[None, :], - ) - edge_id += GRID_E_STRIDE - - -@triton.jit -def rotate_back_l1_forward_kernel( - x_ptr, - wigner_ptr, - coeff_index_ptr, - out_ptr, - num_edges, - reduced_dim, - full_dim, - channels, - x_stride_e, - x_stride_r, - x_stride_c, - wigner_stride_e, - wigner_stride_r, - wigner_stride_k, - out_stride_e, - out_stride_d, - out_stride_c, - BLOCK_CHANNEL: tl.constexpr, - GRID_E_STRIDE: tl.constexpr, -): - """Fused ``local reduced -> global`` rotation specialized for ``lmax=1``.""" - edge_id = tl.program_id(0) - channel_block_id = tl.program_id(1) - channel_offsets = channel_block_id * BLOCK_CHANNEL + tl.arange(0, BLOCK_CHANNEL) - channel_mask = channel_offsets < channels - while edge_id < num_edges: - reduced_values, reduced_mask, coeff_rows = _load_reduced_values_with_index( - x_ptr, - coeff_index_ptr, - edge_id, - reduced_dim, - channel_offsets, - channel_mask, - x_stride_e, - x_stride_r, - x_stride_c, - ) - x_full = _scatter_reduced_to_full_matrix( - reduced_values, - reduced_mask, - coeff_rows, - BLOCK_CHANNEL=BLOCK_CHANNEL, - ) - w_full = _load_full_wigner_matrix( - wigner_ptr, - edge_id, - full_dim, - wigner_stride_e, - wigner_stride_r, - wigner_stride_k, - ).to(x_full.dtype) - y_full = _matvec_l1(w_full, x_full) - full_mask = tl.arange(0, TRITON_SMALL_FULL_DIM) < full_dim - out_ptrs = ( - out_ptr - + edge_id * out_stride_e - + tl.arange(0, TRITON_SMALL_FULL_DIM)[:, None] * out_stride_d - + channel_offsets[None, :] * out_stride_c - ) - tl.store( - out_ptrs, - y_full, - mask=full_mask[:, None] & channel_mask[None, :], - ) - edge_id += GRID_E_STRIDE - - -@triton.jit -def rotate_back_l2_forward_kernel( - x_ptr, - wigner_ptr, - coeff_index_ptr, - out_ptr, - num_edges, - reduced_dim, - full_dim, - channels, - x_stride_e, - x_stride_r, - x_stride_c, - wigner_stride_e, - wigner_stride_r, - wigner_stride_k, - out_stride_e, - out_stride_d, - out_stride_c, - BLOCK_CHANNEL: tl.constexpr, - GRID_E_STRIDE: tl.constexpr, -): - """Fused ``local reduced -> global`` rotation specialized for ``lmax=2``.""" - edge_id = tl.program_id(0) - channel_block_id = tl.program_id(1) - channel_offsets = channel_block_id * BLOCK_CHANNEL + tl.arange(0, BLOCK_CHANNEL) - channel_mask = channel_offsets < channels - while edge_id < num_edges: - reduced_values, reduced_mask, coeff_rows = _load_reduced_values_with_index( - x_ptr, - coeff_index_ptr, - edge_id, - reduced_dim, - channel_offsets, - channel_mask, - x_stride_e, - x_stride_r, - x_stride_c, - ) - x_full = _scatter_reduced_to_full_matrix( - reduced_values, - reduced_mask, - coeff_rows, - BLOCK_CHANNEL=BLOCK_CHANNEL, - ) - w_full = _load_full_wigner_matrix( - wigner_ptr, - edge_id, - full_dim, - wigner_stride_e, - wigner_stride_r, - wigner_stride_k, - ).to(x_full.dtype) - y_full = _matvec_l2(w_full, x_full) - full_mask = tl.arange(0, TRITON_SMALL_FULL_DIM) < full_dim - out_ptrs = ( - out_ptr - + edge_id * out_stride_e - + tl.arange(0, TRITON_SMALL_FULL_DIM)[:, None] * out_stride_d - + channel_offsets[None, :] * out_stride_c - ) - tl.store( - out_ptrs, - y_full, - mask=full_mask[:, None] & channel_mask[None, :], - ) - edge_id += GRID_E_STRIDE - - -@triton.jit -def rotate_back_l3_forward_kernel( - x_ptr, - wigner_ptr, - coeff_index_ptr, - out_ptr, - num_edges, - reduced_dim, - full_dim, - channels, - x_stride_e, - x_stride_r, - x_stride_c, - wigner_stride_e, - wigner_stride_r, - wigner_stride_k, - out_stride_e, - out_stride_d, - out_stride_c, - BLOCK_CHANNEL: tl.constexpr, - GRID_E_STRIDE: tl.constexpr, -): - """Fused ``local reduced -> global`` rotation specialized for ``lmax=3``.""" - edge_id = tl.program_id(0) - channel_block_id = tl.program_id(1) - channel_offsets = channel_block_id * BLOCK_CHANNEL + tl.arange(0, BLOCK_CHANNEL) - channel_mask = channel_offsets < channels - while edge_id < num_edges: - reduced_values, reduced_mask, coeff_rows = _load_reduced_values_with_index( - x_ptr, - coeff_index_ptr, - edge_id, - reduced_dim, - channel_offsets, - channel_mask, - x_stride_e, - x_stride_r, - x_stride_c, - ) - x_full = _scatter_reduced_to_full_matrix( - reduced_values, - reduced_mask, - coeff_rows, - BLOCK_CHANNEL=BLOCK_CHANNEL, - ) - w_full = _load_full_wigner_matrix( - wigner_ptr, - edge_id, - full_dim, - wigner_stride_e, - wigner_stride_r, - wigner_stride_k, - ).to(x_full.dtype) - y_full = tl.dot( - w_full.to(x_full.dtype), - x_full, - input_precision="ieee", - ) - full_mask = tl.arange(0, TRITON_SMALL_FULL_DIM) < full_dim - out_ptrs = ( - out_ptr - + edge_id * out_stride_e - + tl.arange(0, TRITON_SMALL_FULL_DIM)[:, None] * out_stride_d - + channel_offsets[None, :] * out_stride_c - ) - tl.store( - out_ptrs, - y_full, - mask=full_mask[:, None] & channel_mask[None, :], - ) - edge_id += GRID_E_STRIDE - - -@triton.jit -def rotate_back_l1_bwd_dx_kernel( - grad_out_ptr, - wigner_ptr, - coeff_index_ptr, - grad_x_ptr, - num_edges, - reduced_dim, - full_dim, - channels, - grad_out_stride_e, - grad_out_stride_d, - grad_out_stride_c, - wigner_stride_e, - wigner_stride_r, - wigner_stride_k, - grad_x_stride_e, - grad_x_stride_r, - grad_x_stride_c, - BLOCK_CHANNEL: tl.constexpr, - GRID_E_STRIDE: tl.constexpr, -): - """Compute reduced-layout gradients specialized for ``lmax=1``.""" - edge_id = tl.program_id(0) - channel_block_id = tl.program_id(1) - channel_offsets = channel_block_id * BLOCK_CHANNEL + tl.arange(0, BLOCK_CHANNEL) - channel_mask = channel_offsets < channels - while edge_id < num_edges: - full_offsets = tl.arange(0, TRITON_SMALL_FULL_DIM) - full_mask = full_offsets < full_dim - grad_ptrs = ( - grad_out_ptr - + edge_id * grad_out_stride_e - + full_offsets[:, None] * grad_out_stride_d - + channel_offsets[None, :] * grad_out_stride_c - ) - grad_full = tl.load( - grad_ptrs, - mask=full_mask[:, None] & channel_mask[None, :], - other=0.0, - ) - coeff_rows = tl.load( - coeff_index_ptr + full_offsets, - mask=full_offsets < reduced_dim, - other=-1, - ).to(tl.int64) - reduced_mask = full_offsets < reduced_dim - w_full = _load_full_wigner_matrix( - wigner_ptr, - edge_id, - full_dim, - wigner_stride_e, - wigner_stride_r, - wigner_stride_k, - ).to(grad_full.dtype) - dx_full = _matvec_t_l1(w_full, grad_full) - grad_reduced = _select_reduced_from_full_matrix( - dx_full, - reduced_mask, - coeff_rows, - BLOCK_CHANNEL=BLOCK_CHANNEL, - ) - grad_x_ptrs = ( - grad_x_ptr - + edge_id * grad_x_stride_e - + full_offsets[:, None] * grad_x_stride_r - + channel_offsets[None, :] * grad_x_stride_c - ) - tl.store( - grad_x_ptrs, - grad_reduced, - mask=reduced_mask[:, None] & channel_mask[None, :], - ) - edge_id += GRID_E_STRIDE - - -@triton.jit -def rotate_back_l2_bwd_dx_kernel( - grad_out_ptr, - wigner_ptr, - coeff_index_ptr, - grad_x_ptr, - num_edges, - reduced_dim, - full_dim, - channels, - grad_out_stride_e, - grad_out_stride_d, - grad_out_stride_c, - wigner_stride_e, - wigner_stride_r, - wigner_stride_k, - grad_x_stride_e, - grad_x_stride_r, - grad_x_stride_c, - BLOCK_CHANNEL: tl.constexpr, - GRID_E_STRIDE: tl.constexpr, -): - """Compute reduced-layout gradients specialized for ``lmax=2``.""" - edge_id = tl.program_id(0) - channel_block_id = tl.program_id(1) - channel_offsets = channel_block_id * BLOCK_CHANNEL + tl.arange(0, BLOCK_CHANNEL) - channel_mask = channel_offsets < channels - while edge_id < num_edges: - full_offsets = tl.arange(0, TRITON_SMALL_FULL_DIM) - full_mask = full_offsets < full_dim - grad_ptrs = ( - grad_out_ptr - + edge_id * grad_out_stride_e - + full_offsets[:, None] * grad_out_stride_d - + channel_offsets[None, :] * grad_out_stride_c - ) - grad_full = tl.load( - grad_ptrs, - mask=full_mask[:, None] & channel_mask[None, :], - other=0.0, - ) - coeff_rows = tl.load( - coeff_index_ptr + full_offsets, - mask=full_offsets < reduced_dim, - other=-1, - ).to(tl.int64) - reduced_mask = full_offsets < reduced_dim - w_full = _load_full_wigner_matrix( - wigner_ptr, - edge_id, - full_dim, - wigner_stride_e, - wigner_stride_r, - wigner_stride_k, - ).to(grad_full.dtype) - dx_full = _matvec_t_l2(w_full, grad_full) - grad_reduced = _select_reduced_from_full_matrix( - dx_full, - reduced_mask, - coeff_rows, - BLOCK_CHANNEL=BLOCK_CHANNEL, - ) - grad_x_ptrs = ( - grad_x_ptr - + edge_id * grad_x_stride_e - + full_offsets[:, None] * grad_x_stride_r - + channel_offsets[None, :] * grad_x_stride_c - ) - tl.store( - grad_x_ptrs, - grad_reduced, - mask=reduced_mask[:, None] & channel_mask[None, :], - ) - edge_id += GRID_E_STRIDE - - -@triton.jit -def rotate_back_l3_bwd_dx_kernel( - grad_out_ptr, - wigner_ptr, - coeff_index_ptr, - grad_x_ptr, - num_edges, - reduced_dim, - full_dim, - channels, - grad_out_stride_e, - grad_out_stride_d, - grad_out_stride_c, - wigner_stride_e, - wigner_stride_r, - wigner_stride_k, - grad_x_stride_e, - grad_x_stride_r, - grad_x_stride_c, - BLOCK_CHANNEL: tl.constexpr, - GRID_E_STRIDE: tl.constexpr, -): - """Compute reduced-layout gradients specialized for ``lmax=3``.""" - edge_id = tl.program_id(0) - channel_block_id = tl.program_id(1) - channel_offsets = channel_block_id * BLOCK_CHANNEL + tl.arange(0, BLOCK_CHANNEL) - channel_mask = channel_offsets < channels - while edge_id < num_edges: - full_offsets = tl.arange(0, TRITON_SMALL_FULL_DIM) - full_mask = full_offsets < full_dim - grad_ptrs = ( - grad_out_ptr - + edge_id * grad_out_stride_e - + full_offsets[:, None] * grad_out_stride_d - + channel_offsets[None, :] * grad_out_stride_c - ) - grad_full = tl.load( - grad_ptrs, - mask=full_mask[:, None] & channel_mask[None, :], - other=0.0, - ) - coeff_rows = tl.load( - coeff_index_ptr + full_offsets, - mask=full_offsets < reduced_dim, - other=-1, - ).to(tl.int64) - reduced_mask = full_offsets < reduced_dim - w_full = _load_full_wigner_matrix( - wigner_ptr, - edge_id, - full_dim, - wigner_stride_e, - wigner_stride_r, - wigner_stride_k, - ).to(grad_full.dtype) - dx_full = tl.dot( - tl.trans(w_full.to(grad_full.dtype)), - grad_full, - input_precision="ieee", - ) - grad_reduced = _select_reduced_from_full_matrix( - dx_full, - reduced_mask, - coeff_rows, - BLOCK_CHANNEL=BLOCK_CHANNEL, - ) - grad_x_ptrs = ( - grad_x_ptr - + edge_id * grad_x_stride_e - + full_offsets[:, None] * grad_x_stride_r - + channel_offsets[None, :] * grad_x_stride_c - ) - tl.store( - grad_x_ptrs, - grad_reduced, - mask=reduced_mask[:, None] & channel_mask[None, :], - ) - edge_id += GRID_E_STRIDE - - -@triton.jit -def rotate_back_small_bwd_dw_kernel( - grad_out_ptr, - x_ptr, - coeff_index_ptr, - grad_wigner_ptr, - num_edges, - reduced_dim, - full_dim, - channels, - grad_out_stride_e, - grad_out_stride_d, - grad_out_stride_c, - x_stride_e, - x_stride_r, - x_stride_c, - grad_wigner_stride_e, - grad_wigner_stride_r, - grad_wigner_stride_k, - BLOCK_CHANNEL: tl.constexpr, - GRID_E_STRIDE: tl.constexpr, -): - """Compute full Wigner gradients for specialized small-l rotate-back.""" - edge_id = tl.program_id(0) - channel_offsets = tl.arange(0, BLOCK_CHANNEL) - full_offsets = tl.arange(0, TRITON_SMALL_FULL_DIM) - while edge_id < num_edges: - coeff_rows = tl.load( - coeff_index_ptr + full_offsets, - mask=full_offsets < reduced_dim, - other=-1, - ).to(tl.int64) - reduced_mask = full_offsets < reduced_dim - grad_w_acc = tl.zeros( - (TRITON_SMALL_FULL_DIM, TRITON_SMALL_FULL_DIM), - dtype=tl.float32, - ) - channel_start = 0 - while channel_start < channels: - block_offsets = channel_start + channel_offsets - channel_mask = block_offsets < channels - full_mask = full_offsets < full_dim - grad_ptrs = ( - grad_out_ptr - + edge_id * grad_out_stride_e - + full_offsets[:, None] * grad_out_stride_d - + block_offsets[None, :] * grad_out_stride_c - ) - grad_full = tl.load( - grad_ptrs, - mask=full_mask[:, None] & channel_mask[None, :], - other=0.0, - ) - reduced_values, _, _ = _load_reduced_values_with_index( - x_ptr, - coeff_index_ptr, - edge_id, - reduced_dim, - block_offsets, - channel_mask, - x_stride_e, - x_stride_r, - x_stride_c, - ) - x_full = _scatter_reduced_to_full_matrix( - reduced_values, - reduced_mask, - coeff_rows, - BLOCK_CHANNEL=BLOCK_CHANNEL, - ) - grad_w_acc += tl.dot( - grad_full, - tl.trans(x_full.to(grad_full.dtype)), - input_precision="ieee", - ) - channel_start += BLOCK_CHANNEL - grad_w_ptrs = ( - grad_wigner_ptr - + edge_id * grad_wigner_stride_e - + full_offsets[:, None] * grad_wigner_stride_r - + full_offsets[None, :] * grad_wigner_stride_k - ) - full_mask = full_offsets < full_dim - tl.store( - grad_w_ptrs, - grad_w_acc, - mask=full_mask[:, None] & full_mask[None, :], - ) - edge_id += GRID_E_STRIDE diff --git a/deepmd/pt/model/descriptor/sezm_nn/triton/so2_rotation.py b/deepmd/pt/model/descriptor/sezm_nn/triton/so2_rotation.py new file mode 100644 index 0000000000..258159cec0 --- /dev/null +++ b/deepmd/pt/model/descriptor/sezm_nn/triton/so2_rotation.py @@ -0,0 +1,1715 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +# pyright: reportMissingImports=false +# ruff: noqa: ANN001, ANN202 +"""Fused Triton SO(2)/Wigner rotation operators for the SeZM/DPA4 descriptor. + +This module provides a *clean room* Triton implementation of the two rotation +hot paths used by the SeZM SO(2) convolution: + +``rotate_to_local`` (global -> edge-local reduced frame) + For every edge ``e`` with source node ``src[e]``:: + + out[e] = Wrows[e] @ x[src[e]] # (Dm, C) + Wrows[e][m, k] = wigner[e, coeff_index[m], k] # (Dm, D), k < D + + i.e. the eager reference ``bmm(D_to_m, x[src])`` where + ``D_to_m = wigner[:, :D, :D].index_select(1, coeff_index)``. + +``rotate_back`` (edge-local reduced frame -> global) + For every edge ``e``:: + + out[e] = Wcols[e] @ x_local[e] # (D, C) + Wcols[e][d, m] = wigner[e, d, coeff_index[m]] # (D, Dm), d < D + + i.e. the eager reference ``bmm(Dt_from_m, x_local)`` where + ``Dt_from_m = wigner[:, :D, :D].index_select(2, coeff_index)``. + +Design goals +------------ +1. **Fuse the gathers into the GEMM.** The eager / ``torch.compile`` path first + materializes ``D_to_m`` (or ``Dt_from_m``), shape ``(E, Dm, D)``, *and* + ``x[src]``, shape ``(E, D, C)``, before calling ``bmm``. For lmax 10 with + E=100k that is ~9 GB of scratch that is written and immediately re-read. + We instead gather the Wigner rows/columns (by ``coeff_index``) and the node + features (by ``src``) *inside* the kernel, so neither scratch tensor is ever + created. Each edge is one tiny GEMM; this also sidesteps the well-known + inefficiency of cuBLAS strided-batched GEMM on very small matrices. + +2. **Match eager FP32 accuracy.** Every ``tl.dot`` uses + ``input_precision="ieee"`` so the contraction runs in true IEEE FP32 (no + TF32). This keeps the potential-energy surface smooth. + +3. **Compose with ``torch.compile`` (correct forces).** The public ops are + *modern* functional ``torch.library.custom_op`` s (``mutates_args=()``) with + ``register_fake`` + ``register_autograd``. The backward is itself a pair of + functional custom ops. We never use ``torch.autograd.Function`` and never + mutate an input/output tensor in place at the Python level. This is what + makes the gradient w.r.t. ``wigner`` survive ``make_fx`` functionalization + (the legacy ``autograd.Function`` + in-place path drops it, producing wrong + forces under ``torch.compile``). + +Shapes / dtypes +--------------- +``x``/``x_local`` and ``wigner`` are float (fp32 is the supported precision for +the smooth PES; fp16/bf16 also run but accumulate in fp32). ``src`` and +``coeff_index`` are int64. ``E`` (edges) may exceed 2**31 elements once +multiplied by the per-edge matrix size, so all kernels use int64 addressing. +""" + +from __future__ import ( + annotations, +) + +import math + +import torch +from torch import ( + Tensor, +) + +from ..indexing import ( + build_m_major_index, +) + +__all__ = [ + "TRITON_ROTATION_AVAILABLE", + "rotate_back", + "rotate_back_block", + "rotate_back_dense", + "rotate_back_reference", + "rotate_to_local", + "rotate_to_local_block", + "rotate_to_local_dense", + "rotate_to_local_reference", +] + +try: + import triton + import triton.language as tl + + TRITON_ROTATION_AVAILABLE = True +except ImportError: # pragma: no cover - exercised only without triton + TRITON_ROTATION_AVAILABLE = False + + +# ====================================================================== +# Eager reference / fallback implementations +# ====================================================================== +def rotate_to_local_reference( + x: Tensor, + src: Tensor, + wigner: Tensor, + coeff_index: Tensor, + dim_full: int, +) -> Tensor: + """Eager ground-truth for ``rotate_to_local`` (``bmm(D_to_m, x[src])``).""" + d_to_m = wigner[:, :dim_full, :dim_full].index_select(1, coeff_index) + return torch.bmm(d_to_m, x.index_select(0, src)) + + +def rotate_back_reference( + x_local: Tensor, + wigner: Tensor, + coeff_index: Tensor, + dim_full: int, +) -> Tensor: + """Eager ground-truth for ``rotate_back`` (``bmm(Dt_from_m, x_local)``).""" + dt_from_m = wigner[:, :dim_full, :dim_full].index_select(2, coeff_index) + return torch.bmm(dt_from_m, x_local) + + +def _rotate_to_local_bwd_eager( + grad_out: Tensor, + x: Tensor, + src: Tensor, + wigner: Tensor, + coeff_index: Tensor, + dim_full: int, +) -> tuple[Tensor, Tensor]: + """Eager backward of ``rotate_to_local`` returning ``(grad_x, grad_wigner)``.""" + w_rows = wigner[:, :dim_full, :dim_full].index_select(1, coeff_index) # (E,Dm,D) + x_src = x.index_select(0, src) # (E,D,C) + grad_x_src = torch.bmm(w_rows.transpose(1, 2), grad_out) # (E,D,C) + grad_x = torch.zeros_like(x).index_add_(0, src, grad_x_src) + grad_rows = torch.bmm(grad_out, x_src.transpose(1, 2)) # (E,Dm,D) + grad_block = torch.zeros( + grad_out.shape[0], dim_full, dim_full, dtype=wigner.dtype, device=wigner.device + ) + grad_block.index_copy_(1, coeff_index, grad_rows) + grad_wigner = torch.zeros_like(wigner) + grad_wigner[:, :dim_full, :dim_full] = grad_block + return grad_x, grad_wigner + + +def _rotate_back_bwd_eager( + grad_out: Tensor, + x_local: Tensor, + wigner: Tensor, + coeff_index: Tensor, + dim_full: int, +) -> tuple[Tensor, Tensor]: + """Eager backward of ``rotate_back`` returning ``(grad_x_local, grad_wigner)``.""" + w_cols = wigner[:, :dim_full, :dim_full].index_select(2, coeff_index) # (E,D,Dm) + grad_x_local = torch.bmm(w_cols.transpose(1, 2), grad_out) # (E,Dm,C) + grad_cols = torch.bmm(grad_out, x_local.transpose(1, 2)) # (E,D,Dm) + grad_block = torch.zeros( + grad_out.shape[0], dim_full, dim_full, dtype=wigner.dtype, device=wigner.device + ) + grad_block.index_copy_(2, coeff_index, grad_cols) + grad_wigner = torch.zeros_like(wigner) + grad_wigner[:, :dim_full, :dim_full] = grad_block + return grad_x_local, grad_wigner + + +# ====================================================================== +# Tile-size helpers and autotuning configs +# ====================================================================== +def _tile_dim(value: int) -> int: + """Pick a single-tile edge: the next power of two, at least 16. + + Tiles spanning a whole dimension (the non-tiled ``N`` axis and the static + ``BLOCK_N``) must be a power of two (``tl.arange``) *and* a multiple of 16 + (``tl.dot``); powers of two ``>= 16`` satisfy both. Packed dims map as + ``16 -> 16`` (lmax 3), ``36 -> 64`` (lmax 5), ``64 -> 64`` (lmax 7), + ``121 -> 128`` (lmax 10), ``C=64 -> 64``. + """ + tile = 16 + target = max(int(value), 1) + while tile < target: + tile *= 2 + return tile + + +def _autotune_configs() -> list: + """A small curated set of (BLOCK_M, BLOCK_K, num_warps, num_stages) configs. + + The per-edge GEMMs are tiny (M, K, N <= 128). We tile the output-row axis + ``M`` across the grid and stream the contraction axis ``K`` in a pipelined + loop, so the dominant Wigner load overlaps with the matmul. Autotuning over + a handful of shapes lets one source kernel serve lmax 3..10 well (small + tiles for lmax 3, larger tiles / more warps for lmax 10). + """ + return [ + # Tiny tiles: best for lmax 3 (D=16), where a single 16x16 row tile and a + # one-shot K step behave like a per-edge matvec with minimal overhead. + triton.Config({"BLOCK_M": 16, "BLOCK_K": 16}, num_warps=1, num_stages=2), + triton.Config({"BLOCK_M": 16, "BLOCK_K": 16}, num_warps=2, num_stages=2), + triton.Config({"BLOCK_M": 32, "BLOCK_K": 16}, num_warps=2, num_stages=2), + triton.Config({"BLOCK_M": 16, "BLOCK_K": 64}, num_warps=2, num_stages=2), + triton.Config({"BLOCK_M": 32, "BLOCK_K": 32}, num_warps=4, num_stages=3), + triton.Config({"BLOCK_M": 32, "BLOCK_K": 64}, num_warps=4, num_stages=3), + triton.Config({"BLOCK_M": 64, "BLOCK_K": 32}, num_warps=4, num_stages=3), + triton.Config({"BLOCK_M": 64, "BLOCK_K": 64}, num_warps=4, num_stages=4), + triton.Config({"BLOCK_M": 64, "BLOCK_K": 64}, num_warps=8, num_stages=3), + triton.Config({"BLOCK_M": 128, "BLOCK_K": 32}, num_warps=8, num_stages=3), + ] + + +if TRITON_ROTATION_AVAILABLE: + _CONFIGS = _autotune_configs() + _KEY = ["dim_full", "reduced_dim", "channels"] + + # Block-diagonal kernels are fully unrolled over l (LMAX constexpr) and over + # each l-block, with channels vectorized -- there is no GEMM tile to tune, so + # we only sweep the warp count / pipeline depth, keyed on the channel width. + _BD_CONFIGS = [ + triton.Config({}, num_warps=1, num_stages=1), + triton.Config({}, num_warps=2, num_stages=1), + triton.Config({}, num_warps=4, num_stages=1), + triton.Config({}, num_warps=2, num_stages=2), + triton.Config({}, num_warps=4, num_stages=2), + ] + _BD_KEY = ["channels"] + + # ================================================================== + # Triton kernels + # + # Every kernel is one fused-gather GEMM ``C_out = A @ B`` with: + # * grid = (edge, ceil(M / BLOCK_M)) -- one program per (edge, row-tile), + # * a pipelined K-loop streaming BLOCK_K of the contraction at a time, + # * the Wigner row/column gather (by ``coeff_index``) and the node-feature + # gather (by ``src``) folded into the pointer arithmetic, so neither + # ``D_to_m``/``Dt_from_m`` nor ``x[src]`` is ever materialized. + # All stores overwrite their tile (idempotent), which keeps autotuning safe. + # ================================================================== + @triton.autotune(configs=_CONFIGS, key=_KEY) + @triton.jit + def _to_local_fwd_kernel( + x_ptr, + src_ptr, + w_ptr, + idx_ptr, + out_ptr, + n_edge, + reduced_dim, + dim_full, + channels, + x_sn, + x_sd, + x_sc, + w_se, + w_sr, + w_sk, + o_se, + o_sr, + o_sc, + BLOCK_M: tl.constexpr, + BLOCK_K: tl.constexpr, + BLOCK_N: tl.constexpr, + ): + """``out[e,m,c] = sum_k W[e, coeff[m], k] * x[src[e], k, c]`` (M=Dm,K=D,N=C).""" + edge = tl.program_id(0).to(tl.int64) + row = tl.program_id(1) * BLOCK_M + tl.arange(0, BLOCK_M) # over Dm + chan = tl.arange(0, BLOCK_N) # over C + row_mask = row < reduced_dim + chan_mask = chan < channels + + src_idx = tl.load(src_ptr + edge).to(tl.int64) + coeff_rows = tl.load(idx_ptr + row, mask=row_mask, other=0).to(tl.int64) + + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + for k0 in range(0, tl.cdiv(dim_full, BLOCK_K)): + kk = k0 * BLOCK_K + tl.arange(0, BLOCK_K) # over D + k_mask = kk < dim_full + w_tile = tl.load( + w_ptr + edge * w_se + coeff_rows[:, None] * w_sr + kk[None, :] * w_sk, + mask=row_mask[:, None] & k_mask[None, :], + other=0.0, + ) # (BLOCK_M, BLOCK_K) = W[coeff[m], k] + x_tile = tl.load( + x_ptr + src_idx * x_sn + kk[:, None] * x_sd + chan[None, :] * x_sc, + mask=k_mask[:, None] & chan_mask[None, :], + other=0.0, + ) # (BLOCK_K, BLOCK_N) = x[src, k, c] + acc = tl.dot(w_tile.to(x_tile.dtype), x_tile, acc, input_precision="ieee") + + tl.store( + out_ptr + edge * o_se + row[:, None] * o_sr + chan[None, :] * o_sc, + acc.to(out_ptr.dtype.element_ty), + mask=row_mask[:, None] & chan_mask[None, :], + ) + + @triton.autotune(configs=_CONFIGS, key=_KEY, reset_to_zero=["gx_ptr"]) + @triton.jit + def _to_local_bwd_dx_kernel( + go_ptr, + src_ptr, + w_ptr, + idx_ptr, + gx_ptr, + n_edge, + reduced_dim, + dim_full, + channels, + go_se, + go_sr, + go_sc, + w_se, + w_sr, + w_sk, + gx_sn, + gx_sd, + gx_sc, + BLOCK_M: tl.constexpr, + BLOCK_K: tl.constexpr, + BLOCK_N: tl.constexpr, + ): + """``grad_x[src[e],d,c] += sum_m W[e, coeff[m], d] * grad_out[e,m,c]``. + + (M=D, K=Dm, N=C). The per-edge source gradient is atomically scattered + straight into the zero-initialized ``grad_x`` (no ``x[src]``-sized + scratch). ``reset_to_zero`` keeps the autotuner's trial runs from + polluting the accumulator. + """ + edge = tl.program_id(0).to(tl.int64) + drow = tl.program_id(1) * BLOCK_M + tl.arange(0, BLOCK_M) # over D + chan = tl.arange(0, BLOCK_N) # over C + d_mask = drow < dim_full + chan_mask = chan < channels + + src_idx = tl.load(src_ptr + edge).to(tl.int64) + + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + for k0 in range(0, tl.cdiv(reduced_dim, BLOCK_K)): + mm = k0 * BLOCK_K + tl.arange(0, BLOCK_K) # over Dm + m_mask = mm < reduced_dim + coeff = tl.load(idx_ptr + mm, mask=m_mask, other=0).to(tl.int64) + w_tile = tl.load( + w_ptr + edge * w_se + coeff[None, :] * w_sr + drow[:, None] * w_sk, + mask=d_mask[:, None] & m_mask[None, :], + other=0.0, + ) # (BLOCK_M(d), BLOCK_K(m)) = W[coeff[m], d] + go_tile = tl.load( + go_ptr + edge * go_se + mm[:, None] * go_sr + chan[None, :] * go_sc, + mask=m_mask[:, None] & chan_mask[None, :], + other=0.0, + ) # (BLOCK_K(m), BLOCK_N(c)) + acc = tl.dot(w_tile.to(go_tile.dtype), go_tile, acc, input_precision="ieee") + + tl.atomic_add( + gx_ptr + src_idx * gx_sn + drow[:, None] * gx_sd + chan[None, :] * gx_sc, + acc, + mask=d_mask[:, None] & chan_mask[None, :], + ) + + @triton.autotune(configs=_CONFIGS, key=_KEY) + @triton.jit + def _to_local_bwd_dw_kernel( + go_ptr, + x_ptr, + src_ptr, + idx_ptr, + gw_ptr, + n_edge, + reduced_dim, + dim_full, + channels, + go_se, + go_sr, + go_sc, + x_sn, + x_sd, + x_sc, + gw_se, + gw_sr, + gw_sk, + BLOCK_M: tl.constexpr, + BLOCK_K: tl.constexpr, + BLOCK_N: tl.constexpr, + ): + """``grad_W[e, coeff[m], d] = sum_c grad_out[e,m,c] * x[src[e], d, c]``. + + (M=Dm, K=C, N=D). Writes directly into rows ``coeff_index`` of the + zero-initialized ``grad_wigner``. + """ + edge = tl.program_id(0).to(tl.int64) + mrow = tl.program_id(1) * BLOCK_M + tl.arange(0, BLOCK_M) # over Dm + dcol = tl.arange(0, BLOCK_N) # over D + m_mask = mrow < reduced_dim + d_mask = dcol < dim_full + + coeff = tl.load(idx_ptr + mrow, mask=m_mask, other=0).to(tl.int64) + src_idx = tl.load(src_ptr + edge).to(tl.int64) + + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + for k0 in range(0, tl.cdiv(channels, BLOCK_K)): + cc = k0 * BLOCK_K + tl.arange(0, BLOCK_K) # over C + c_mask = cc < channels + go_tile = tl.load( + go_ptr + edge * go_se + mrow[:, None] * go_sr + cc[None, :] * go_sc, + mask=m_mask[:, None] & c_mask[None, :], + other=0.0, + ) # (BLOCK_M(m), BLOCK_K(c)) + x_tile = tl.load( + x_ptr + src_idx * x_sn + dcol[None, :] * x_sd + cc[:, None] * x_sc, + mask=c_mask[:, None] & d_mask[None, :], + other=0.0, + ) # (BLOCK_K(c), BLOCK_N(d)) = x[src, d, c] + acc = tl.dot(go_tile.to(x_tile.dtype), x_tile, acc, input_precision="ieee") + + tl.store( + gw_ptr + edge * gw_se + coeff[:, None] * gw_sr + dcol[None, :] * gw_sk, + acc.to(gw_ptr.dtype.element_ty), + mask=m_mask[:, None] & d_mask[None, :], + ) + + # ``rotate_back`` reads the Wigner *columns* selected by ``coeff_index``. + # Gathering columns of a row-major ``(E, D, D)`` tensor is uncoalesced, so + # instead we read *dense* Wigner rows (coalesced last axis) and gather / + # scatter the small ``x_local`` through the inverse permutation + # ``inv[k] = m if coeff[m]==k else -1``. For ``mmax==lmax`` (a full + # permutation) this is the same flop count with far better memory behaviour. + @triton.autotune(configs=_CONFIGS, key=_KEY) + @triton.jit + def _back_fwd_kernel( + xl_ptr, + w_ptr, + inv_ptr, + out_ptr, + n_edge, + reduced_dim, + dim_full, + channels, + xl_se, + xl_sr, + xl_sc, + w_se, + w_sr, + w_sk, + o_se, + o_sd, + o_sc, + BLOCK_M: tl.constexpr, + BLOCK_K: tl.constexpr, + BLOCK_N: tl.constexpr, + ): + """``out[e,d,c] = sum_k W[e,d,k] * x_local[e, inv[k], c]`` (M=D, K=D, N=C).""" + edge = tl.program_id(0).to(tl.int64) + drow = tl.program_id(1) * BLOCK_M + tl.arange(0, BLOCK_M) # over D + chan = tl.arange(0, BLOCK_N) # over C + d_mask = drow < dim_full + chan_mask = chan < channels + + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + for k0 in range(0, tl.cdiv(dim_full, BLOCK_K)): + kk = k0 * BLOCK_K + tl.arange(0, BLOCK_K) # over D (contraction) + k_mask = kk < dim_full + inv_k = tl.load(inv_ptr + kk, mask=k_mask, other=-1).to(tl.int64) + keep = inv_k >= 0 + w_tile = tl.load( + w_ptr + edge * w_se + drow[:, None] * w_sr + kk[None, :] * w_sk, + mask=d_mask[:, None] & k_mask[None, :], + other=0.0, + ) # (BLOCK_M(d), BLOCK_K(k)) = W[d, k] (k contiguous -> coalesced) + xl_tile = tl.load( + xl_ptr + edge * xl_se + inv_k[:, None] * xl_sr + chan[None, :] * xl_sc, + mask=keep[:, None] & chan_mask[None, :], + other=0.0, + ) # (BLOCK_K(k), BLOCK_N(c)) = x_local[inv[k], c] + acc = tl.dot(w_tile.to(xl_tile.dtype), xl_tile, acc, input_precision="ieee") + + tl.store( + out_ptr + edge * o_se + drow[:, None] * o_sd + chan[None, :] * o_sc, + acc.to(out_ptr.dtype.element_ty), + mask=d_mask[:, None] & chan_mask[None, :], + ) + + @triton.autotune(configs=_CONFIGS, key=_KEY) + @triton.jit + def _back_bwd_dx_kernel( + go_ptr, + w_ptr, + inv_ptr, + gxl_ptr, + n_edge, + reduced_dim, + dim_full, + channels, + go_se, + go_sd, + go_sc, + w_se, + w_sr, + w_sk, + gxl_se, + gxl_sr, + gxl_sc, + BLOCK_M: tl.constexpr, + BLOCK_K: tl.constexpr, + BLOCK_N: tl.constexpr, + ): + """``grad_x_local[e, inv[k], c] = sum_d W[e,d,k] * grad_out[e,d,c]``. + + (M=D, K=D, N=C). Computes the dense ``k``-indexed gradient with coalesced + Wigner reads, then scatters each full row ``k`` into reduced row + ``inv[k]`` of ``grad_x_local``. + """ + edge = tl.program_id(0).to(tl.int64) + krow = tl.program_id(1) * BLOCK_M + tl.arange(0, BLOCK_M) # over D + chan = tl.arange(0, BLOCK_N) # over C + k_mask = krow < dim_full + chan_mask = chan < channels + + inv_k = tl.load(inv_ptr + krow, mask=k_mask, other=-1).to(tl.int64) + keep = inv_k >= 0 + + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + for k0 in range(0, tl.cdiv(dim_full, BLOCK_K)): + dd = k0 * BLOCK_K + tl.arange(0, BLOCK_K) # over D (contraction) + d_mask = dd < dim_full + w_tile = tl.load( + w_ptr + edge * w_se + dd[None, :] * w_sr + krow[:, None] * w_sk, + mask=k_mask[:, None] & d_mask[None, :], + other=0.0, + ) # (BLOCK_M(k), BLOCK_K(d)) = W[d, k] (k contiguous -> coalesced) + go_tile = tl.load( + go_ptr + edge * go_se + dd[:, None] * go_sd + chan[None, :] * go_sc, + mask=d_mask[:, None] & chan_mask[None, :], + other=0.0, + ) # (BLOCK_K(d), BLOCK_N(c)) + acc = tl.dot(w_tile.to(go_tile.dtype), go_tile, acc, input_precision="ieee") + + tl.store( + gxl_ptr + edge * gxl_se + inv_k[:, None] * gxl_sr + chan[None, :] * gxl_sc, + acc.to(gxl_ptr.dtype.element_ty), + mask=keep[:, None] & chan_mask[None, :], + ) + + @triton.autotune(configs=_CONFIGS, key=_KEY) + @triton.jit + def _back_bwd_dw_kernel( + go_ptr, + xl_ptr, + inv_ptr, + gw_ptr, + n_edge, + reduced_dim, + dim_full, + channels, + go_se, + go_sd, + go_sc, + xl_se, + xl_sr, + xl_sc, + gw_se, + gw_sr, + gw_sk, + BLOCK_M: tl.constexpr, + BLOCK_K: tl.constexpr, + BLOCK_N: tl.constexpr, + ): + """``grad_W[e,d,k] = sum_c grad_out[e,d,c] * x_local[e, inv[k], c]``. + + (M=D, K=C, N=D). Writes the dense ``(D, D)`` block of ``grad_wigner`` + with a coalesced last axis; columns ``k`` not selected by ``coeff_index`` + receive zero (``inv[k] < 0``), matching the eager column gather. + """ + edge = tl.program_id(0).to(tl.int64) + drow = tl.program_id(1) * BLOCK_M + tl.arange(0, BLOCK_M) # over D + kcol = tl.arange(0, BLOCK_N) # over D + d_mask = drow < dim_full + k_mask = kcol < dim_full + + inv_k = tl.load(inv_ptr + kcol, mask=k_mask, other=-1).to(tl.int64) + keep = inv_k >= 0 + + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + for k0 in range(0, tl.cdiv(channels, BLOCK_K)): + cc = k0 * BLOCK_K + tl.arange(0, BLOCK_K) # over C (contraction) + c_mask = cc < channels + go_tile = tl.load( + go_ptr + edge * go_se + drow[:, None] * go_sd + cc[None, :] * go_sc, + mask=d_mask[:, None] & c_mask[None, :], + other=0.0, + ) # (BLOCK_M(d), BLOCK_K(c)) + xl_tile = tl.load( + xl_ptr + edge * xl_se + inv_k[None, :] * xl_sr + cc[:, None] * xl_sc, + mask=c_mask[:, None] & keep[None, :], + other=0.0, + ) # (BLOCK_K(c), BLOCK_N(k)) = x_local[inv[k], c] + acc = tl.dot( + go_tile.to(xl_tile.dtype), xl_tile, acc, input_precision="ieee" + ) + + tl.store( + gw_ptr + edge * gw_se + drow[:, None] * gw_sr + kcol[None, :] * gw_sk, + acc.to(gw_ptr.dtype.element_ty), + mask=d_mask[:, None] & k_mask[None, :], + ) + + # ================================================================== + # Block-diagonal kernels (mmax == 1, block-diagonal Wigner-D) + # + # The Wigner-D matrix is block-diagonal by degree ``l``: block ``l`` is the + # ``(2l+1) x (2l+1)`` sub-matrix on rows/cols ``[l^2 : (l+1)^2]`` and every + # off-(l-block) entry is exactly 0. With ``mmax == 1`` the reduced layout + # keeps, per degree ``l``, the orders ``m in {0}`` (l == 0) or + # ``{0, -1, +1}`` (l >= 1). Output coefficient ``(l, m)`` therefore contracts + # ONLY over the ``2l+1`` inputs of block ``l`` -- never the full ``D``. + # + # The m-major reduced index and the packed Wigner row/col are pure functions + # of ``(l, m, LMAX)``:: + # + # reduced index: m=0 -> l, m=-1 -> LMAX+l, m=+1 -> 2*LMAX+l + # packed (l, m): l^2 + l + m (so m=0 -> l^2+l, m=-1 -> -1, m=+1 -> +1) + # + # so the kernels need no ``coeff_index`` tensor: with ``LMAX`` a constexpr we + # fully unroll over ``l`` and over each block, contracting exactly the + # structural non-zeros (no padding, no wasted FLOPs). Channels are the + # vectorized axis (``BLOCK_C`` spans the full width ``C``), so the backward + # Wigner gradient is a single in-program ``tl.sum`` over channels. + @triton.autotune(configs=_BD_CONFIGS, key=_BD_KEY) + @triton.jit + def _bd_to_local_fwd_kernel( + x_ptr, + src_ptr, + w_ptr, + out_ptr, + n_edge, + channels, + x_sn, + x_sd, + x_sc, + w_se, + w_sr, + w_sk, + o_se, + o_sr, + o_sc, + LMAX: tl.constexpr, + BLOCK_C: tl.constexpr, + ): + """``out[e,(l,m),c] = sum_{j} W[e, l^2+l+m, l^2+j] * x[src[e], l^2+j, c]``.""" + edge = tl.program_id(0).to(tl.int64) + chan = tl.arange(0, BLOCK_C) + cmask = chan < channels + src_idx = tl.load(src_ptr + edge).to(tl.int64) + + for l in tl.static_range(0, LMAX + 1): + base = l * l + r0 = base + l # packed row of order m=0 + acc0 = tl.zeros((BLOCK_C,), dtype=tl.float32) + acc_m = tl.zeros((BLOCK_C,), dtype=tl.float32) + acc_p = tl.zeros((BLOCK_C,), dtype=tl.float32) + for j in tl.static_range(0, 2 * l + 1): + col = base + j + x_vec = tl.load( + x_ptr + src_idx * x_sn + col * x_sd + chan * x_sc, + mask=cmask, + other=0.0, + ).to(tl.float32) + acc0 += tl.load(w_ptr + edge * w_se + r0 * w_sr + col * w_sk) * x_vec + if l >= 1: + acc_m += ( + tl.load(w_ptr + edge * w_se + (r0 - 1) * w_sr + col * w_sk) + * x_vec + ) + acc_p += ( + tl.load(w_ptr + edge * w_se + (r0 + 1) * w_sr + col * w_sk) + * x_vec + ) + tl.store( + out_ptr + edge * o_se + l * o_sr + chan * o_sc, + acc0.to(out_ptr.dtype.element_ty), + mask=cmask, + ) + if l >= 1: + tl.store( + out_ptr + edge * o_se + (LMAX + l) * o_sr + chan * o_sc, + acc_m.to(out_ptr.dtype.element_ty), + mask=cmask, + ) + tl.store( + out_ptr + edge * o_se + (2 * LMAX + l) * o_sr + chan * o_sc, + acc_p.to(out_ptr.dtype.element_ty), + mask=cmask, + ) + + @triton.autotune(configs=_BD_CONFIGS, key=_BD_KEY, reset_to_zero=["gx_ptr"]) + @triton.jit + def _bd_to_local_bwd_kernel( + go_ptr, + x_ptr, + src_ptr, + w_ptr, + gx_ptr, + gw_ptr, + n_edge, + channels, + go_se, + go_sr, + go_sc, + x_sn, + x_sd, + x_sc, + w_se, + w_sr, + w_sk, + gx_sn, + gx_sd, + gx_sc, + gw_se, + gw_sr, + gw_sk, + LMAX: tl.constexpr, + BLOCK_C: tl.constexpr, + ): + """Fused block-diagonal backward of ``rotate_to_local``. + + Per edge (full channel width in one program): scatters + ``grad_x[src, l^2+j, :] += sum_m W[l^2+l+m, l^2+j] * grad_out[(l,m), :]`` + and writes ``grad_W[l^2+l+m, l^2+j] = sum_c grad_out[(l,m),c] * x[l^2+j,c]`` + for the structural non-zeros only. + """ + edge = tl.program_id(0).to(tl.int64) + chan = tl.arange(0, BLOCK_C) + cmask = chan < channels + src_idx = tl.load(src_ptr + edge).to(tl.int64) + + for l in tl.static_range(0, LMAX + 1): + base = l * l + r0 = base + l + go0 = tl.load( + go_ptr + edge * go_se + l * go_sr + chan * go_sc, + mask=cmask, + other=0.0, + ).to(tl.float32) + if l >= 1: + go_m = tl.load( + go_ptr + edge * go_se + (LMAX + l) * go_sr + chan * go_sc, + mask=cmask, + other=0.0, + ).to(tl.float32) + go_p = tl.load( + go_ptr + edge * go_se + (2 * LMAX + l) * go_sr + chan * go_sc, + mask=cmask, + other=0.0, + ).to(tl.float32) + for j in tl.static_range(0, 2 * l + 1): + col = base + j + x_vec = tl.load( + x_ptr + src_idx * x_sn + col * x_sd + chan * x_sc, + mask=cmask, + other=0.0, + ).to(tl.float32) + w0 = tl.load(w_ptr + edge * w_se + r0 * w_sr + col * w_sk) + gx_row = w0 * go0 + tl.store( + gw_ptr + edge * gw_se + r0 * gw_sr + col * gw_sk, + tl.sum(go0 * x_vec).to(gw_ptr.dtype.element_ty), + ) + if l >= 1: + wm = tl.load(w_ptr + edge * w_se + (r0 - 1) * w_sr + col * w_sk) + wp = tl.load(w_ptr + edge * w_se + (r0 + 1) * w_sr + col * w_sk) + gx_row += wm * go_m + wp * go_p + tl.store( + gw_ptr + edge * gw_se + (r0 - 1) * gw_sr + col * gw_sk, + tl.sum(go_m * x_vec).to(gw_ptr.dtype.element_ty), + ) + tl.store( + gw_ptr + edge * gw_se + (r0 + 1) * gw_sr + col * gw_sk, + tl.sum(go_p * x_vec).to(gw_ptr.dtype.element_ty), + ) + tl.atomic_add( + gx_ptr + src_idx * gx_sn + col * gx_sd + chan * gx_sc, + gx_row, + mask=cmask, + ) + + @triton.autotune(configs=_BD_CONFIGS, key=_BD_KEY) + @triton.jit + def _bd_back_fwd_kernel( + xl_ptr, + w_ptr, + out_ptr, + n_edge, + channels, + xl_se, + xl_sr, + xl_sc, + w_se, + w_sr, + w_sk, + o_se, + o_sd, + o_sc, + LMAX: tl.constexpr, + BLOCK_C: tl.constexpr, + ): + """``out[e, l^2+j, c] = sum_m W[e, l^2+j, l^2+l+m] * x_local[(l,m), c]``.""" + edge = tl.program_id(0).to(tl.int64) + chan = tl.arange(0, BLOCK_C) + cmask = chan < channels + + for l in tl.static_range(0, LMAX + 1): + base = l * l + r0 = base + l # packed col of order m=0 + xl0 = tl.load( + xl_ptr + edge * xl_se + l * xl_sr + chan * xl_sc, + mask=cmask, + other=0.0, + ).to(tl.float32) + if l >= 1: + xl_m = tl.load( + xl_ptr + edge * xl_se + (LMAX + l) * xl_sr + chan * xl_sc, + mask=cmask, + other=0.0, + ).to(tl.float32) + xl_p = tl.load( + xl_ptr + edge * xl_se + (2 * LMAX + l) * xl_sr + chan * xl_sc, + mask=cmask, + other=0.0, + ).to(tl.float32) + for j in tl.static_range(0, 2 * l + 1): + d = base + j # full packed output row + acc = tl.load(w_ptr + edge * w_se + d * w_sr + r0 * w_sk) * xl0 + if l >= 1: + acc += ( + tl.load(w_ptr + edge * w_se + d * w_sr + (r0 - 1) * w_sk) * xl_m + ) + acc += ( + tl.load(w_ptr + edge * w_se + d * w_sr + (r0 + 1) * w_sk) * xl_p + ) + tl.store( + out_ptr + edge * o_se + d * o_sd + chan * o_sc, + acc.to(out_ptr.dtype.element_ty), + mask=cmask, + ) + + @triton.autotune(configs=_BD_CONFIGS, key=_BD_KEY) + @triton.jit + def _bd_back_bwd_kernel( + go_ptr, + xl_ptr, + w_ptr, + gxl_ptr, + gw_ptr, + n_edge, + channels, + go_se, + go_sd, + go_sc, + xl_se, + xl_sr, + xl_sc, + w_se, + w_sr, + w_sk, + gxl_se, + gxl_sr, + gxl_sc, + gw_se, + gw_sr, + gw_sk, + LMAX: tl.constexpr, + BLOCK_C: tl.constexpr, + ): + """Fused block-diagonal backward of ``rotate_back``. + + Per edge (full channel width in one program): writes + ``grad_x_local[(l,m), :] = sum_j W[l^2+j, l^2+l+m] * grad_out[l^2+j, :]`` + (no scatter -- ``x_local`` is per-edge) and + ``grad_W[l^2+j, l^2+l+m] = sum_c grad_out[l^2+j, c] * x_local[(l,m), c]``. + """ + edge = tl.program_id(0).to(tl.int64) + chan = tl.arange(0, BLOCK_C) + cmask = chan < channels + + for l in tl.static_range(0, LMAX + 1): + base = l * l + r0 = base + l # packed col of order m=0 + xl0 = tl.load( + xl_ptr + edge * xl_se + l * xl_sr + chan * xl_sc, + mask=cmask, + other=0.0, + ).to(tl.float32) + gxl0 = tl.zeros((BLOCK_C,), dtype=tl.float32) + if l >= 1: + xl_m = tl.load( + xl_ptr + edge * xl_se + (LMAX + l) * xl_sr + chan * xl_sc, + mask=cmask, + other=0.0, + ).to(tl.float32) + xl_p = tl.load( + xl_ptr + edge * xl_se + (2 * LMAX + l) * xl_sr + chan * xl_sc, + mask=cmask, + other=0.0, + ).to(tl.float32) + gxl_m = tl.zeros((BLOCK_C,), dtype=tl.float32) + gxl_p = tl.zeros((BLOCK_C,), dtype=tl.float32) + for j in tl.static_range(0, 2 * l + 1): + d = base + j # full packed row (output of forward / grad_out row) + go_d = tl.load( + go_ptr + edge * go_se + d * go_sd + chan * go_sc, + mask=cmask, + other=0.0, + ).to(tl.float32) + w0 = tl.load(w_ptr + edge * w_se + d * w_sr + r0 * w_sk) + gxl0 += w0 * go_d + tl.store( + gw_ptr + edge * gw_se + d * gw_sr + r0 * gw_sk, + tl.sum(go_d * xl0).to(gw_ptr.dtype.element_ty), + ) + if l >= 1: + wm = tl.load(w_ptr + edge * w_se + d * w_sr + (r0 - 1) * w_sk) + wp = tl.load(w_ptr + edge * w_se + d * w_sr + (r0 + 1) * w_sk) + gxl_m += wm * go_d + gxl_p += wp * go_d + tl.store( + gw_ptr + edge * gw_se + d * gw_sr + (r0 - 1) * gw_sk, + tl.sum(go_d * xl_m).to(gw_ptr.dtype.element_ty), + ) + tl.store( + gw_ptr + edge * gw_se + d * gw_sr + (r0 + 1) * gw_sk, + tl.sum(go_d * xl_p).to(gw_ptr.dtype.element_ty), + ) + tl.store( + gxl_ptr + edge * gxl_se + l * gxl_sr + chan * gxl_sc, + gxl0.to(gxl_ptr.dtype.element_ty), + mask=cmask, + ) + if l >= 1: + tl.store( + gxl_ptr + edge * gxl_se + (LMAX + l) * gxl_sr + chan * gxl_sc, + gxl_m.to(gxl_ptr.dtype.element_ty), + mask=cmask, + ) + tl.store( + gxl_ptr + edge * gxl_se + (2 * LMAX + l) * gxl_sr + chan * gxl_sc, + gxl_p.to(gxl_ptr.dtype.element_ty), + mask=cmask, + ) + + +# ====================================================================== +# Triton launch wrappers +# ====================================================================== +def _grid_over_rows(n_edge: int, rows: int): + """Grid callable: one program per (edge, BLOCK_M-sized row tile).""" + return lambda meta: (n_edge, triton.cdiv(rows, meta["BLOCK_M"])) + + +def _inverse_index(coeff_index: Tensor, dim_full: int) -> Tensor: + """Inverse permutation ``inv[k] = m`` where ``coeff_index[m] == k`` else ``-1``. + + Maps a full packed position ``k`` back to its reduced-layout slot. Used by the + ``rotate_back`` kernels so they can read dense Wigner rows (coalesced) and + gather/scatter the small ``x_local`` instead of gathering Wigner columns. + """ + inv = torch.full((int(dim_full),), -1, dtype=torch.int64, device=coeff_index.device) + inv[coeff_index] = torch.arange( + coeff_index.numel(), dtype=torch.int64, device=coeff_index.device + ) + return inv + + +def _launch_rotate_to_local_fwd( + x: Tensor, + src: Tensor, + wigner: Tensor, + coeff_index: Tensor, + dim_full: int, +) -> Tensor: + n_edge = int(src.shape[0]) + reduced_dim = int(coeff_index.shape[0]) + channels = int(x.shape[2]) + out = torch.empty((n_edge, reduced_dim, channels), dtype=x.dtype, device=x.device) + if n_edge == 0: + return out + _to_local_fwd_kernel[_grid_over_rows(n_edge, reduced_dim)]( + x, + src, + wigner, + coeff_index, + out, + n_edge, + reduced_dim, + dim_full, + channels, + x.stride(0), + x.stride(1), + x.stride(2), + wigner.stride(0), + wigner.stride(1), + wigner.stride(2), + out.stride(0), + out.stride(1), + out.stride(2), + BLOCK_N=_tile_dim(channels), + ) + return out + + +def _launch_rotate_to_local_bwd( + grad_out: Tensor, + x: Tensor, + src: Tensor, + wigner: Tensor, + coeff_index: Tensor, + dim_full: int, +) -> tuple[Tensor, Tensor]: + n_edge = int(src.shape[0]) + reduced_dim = int(coeff_index.shape[0]) + channels = int(x.shape[2]) + grad_x = torch.zeros_like(x) + grad_wigner = torch.zeros_like(wigner) + if n_edge == 0: + return grad_x, grad_wigner + + # --- grad_x: per-edge GEMM atomically scattered into grad_x by src --- + _to_local_bwd_dx_kernel[_grid_over_rows(n_edge, dim_full)]( + grad_out, + src, + wigner, + coeff_index, + grad_x, + n_edge, + reduced_dim, + dim_full, + channels, + grad_out.stride(0), + grad_out.stride(1), + grad_out.stride(2), + wigner.stride(0), + wigner.stride(1), + wigner.stride(2), + grad_x.stride(0), + grad_x.stride(1), + grad_x.stride(2), + BLOCK_N=_tile_dim(channels), + ) + + # --- grad_wigner: per-edge GEMM written into rows ``coeff_index`` --- + _to_local_bwd_dw_kernel[_grid_over_rows(n_edge, reduced_dim)]( + grad_out, + x, + src, + coeff_index, + grad_wigner, + n_edge, + reduced_dim, + dim_full, + channels, + grad_out.stride(0), + grad_out.stride(1), + grad_out.stride(2), + x.stride(0), + x.stride(1), + x.stride(2), + grad_wigner.stride(0), + grad_wigner.stride(1), + grad_wigner.stride(2), + BLOCK_N=_tile_dim(dim_full), + ) + return grad_x, grad_wigner + + +def _launch_rotate_back_fwd( + x_local: Tensor, + wigner: Tensor, + coeff_index: Tensor, + dim_full: int, +) -> Tensor: + n_edge = int(x_local.shape[0]) + reduced_dim = int(coeff_index.shape[0]) + channels = int(x_local.shape[2]) + out = torch.empty( + (n_edge, dim_full, channels), dtype=x_local.dtype, device=x_local.device + ) + if n_edge == 0: + return out + inv_index = _inverse_index(coeff_index, dim_full) + _back_fwd_kernel[_grid_over_rows(n_edge, dim_full)]( + x_local, + wigner, + inv_index, + out, + n_edge, + reduced_dim, + dim_full, + channels, + x_local.stride(0), + x_local.stride(1), + x_local.stride(2), + wigner.stride(0), + wigner.stride(1), + wigner.stride(2), + out.stride(0), + out.stride(1), + out.stride(2), + BLOCK_N=_tile_dim(channels), + ) + return out + + +def _launch_rotate_back_bwd( + grad_out: Tensor, + x_local: Tensor, + wigner: Tensor, + coeff_index: Tensor, + dim_full: int, +) -> tuple[Tensor, Tensor]: + n_edge = int(x_local.shape[0]) + reduced_dim = int(coeff_index.shape[0]) + channels = int(x_local.shape[2]) + grad_x_local = torch.empty_like(x_local) + grad_wigner = torch.zeros_like(wigner) + if n_edge == 0: + return grad_x_local, grad_wigner + + inv_index = _inverse_index(coeff_index, dim_full) + _back_bwd_dx_kernel[_grid_over_rows(n_edge, dim_full)]( + grad_out, + wigner, + inv_index, + grad_x_local, + n_edge, + reduced_dim, + dim_full, + channels, + grad_out.stride(0), + grad_out.stride(1), + grad_out.stride(2), + wigner.stride(0), + wigner.stride(1), + wigner.stride(2), + grad_x_local.stride(0), + grad_x_local.stride(1), + grad_x_local.stride(2), + BLOCK_N=_tile_dim(channels), + ) + _back_bwd_dw_kernel[_grid_over_rows(n_edge, dim_full)]( + grad_out, + x_local, + inv_index, + grad_wigner, + n_edge, + reduced_dim, + dim_full, + channels, + grad_out.stride(0), + grad_out.stride(1), + grad_out.stride(2), + x_local.stride(0), + x_local.stride(1), + x_local.stride(2), + grad_wigner.stride(0), + grad_wigner.stride(1), + grad_wigner.stride(2), + BLOCK_N=_tile_dim(dim_full), + ) + return grad_x_local, grad_wigner + + +# ====================================================================== +# Block-diagonal launch wrappers + layout detection (mmax == 1) +# ====================================================================== +def _block_layout_lmax(coeff_index: Tensor, dim_full: int) -> int: + """Return ``lmax`` if ``(coeff_index, dim_full)`` is the m-major ``mmax=1`` + layout that the block-diagonal kernels assume, else ``-1``. + + Detection uses ONLY shapes / python ints -- never tensor *values* -- so it is + safe under ``make_fx`` / fake-tensor tracing (the production compiled + inference path). The test is: ``dim_full`` is a perfect square ``(lmax+1)^2`` + and ``Dm == 3*lmax+1``. For a fixed ``lmax`` the reduced size ``Dm`` is + strictly increasing in ``mmax`` (``lmax+1``, ``3*lmax+1``, ``5*lmax-1``, ...), + so ``Dm == 3*lmax+1`` uniquely pins ``mmax == 1``; combined with the model's + canonical ``build_m_major_index`` ordering this fully determines the layout. + """ + dim_full = int(dim_full) + root = math.isqrt(dim_full) + if root * root != dim_full: + return -1 + lmax = root - 1 + try: + numel = int(coeff_index.shape[0]) + except Exception: # pragma: no cover - exotic shape proxies + return -1 + if lmax < 1 or numel != 3 * lmax + 1: + return -1 + return lmax + + +def _launch_bd_to_local_fwd( + x: Tensor, src: Tensor, wigner: Tensor, lmax: int +) -> Tensor: + n_edge = int(src.shape[0]) + channels = int(x.shape[2]) + out = torch.empty((n_edge, 3 * lmax + 1, channels), dtype=x.dtype, device=x.device) + if n_edge == 0: + return out + _bd_to_local_fwd_kernel[(n_edge,)]( + x, + src, + wigner, + out, + n_edge, + channels, + x.stride(0), + x.stride(1), + x.stride(2), + wigner.stride(0), + wigner.stride(1), + wigner.stride(2), + out.stride(0), + out.stride(1), + out.stride(2), + LMAX=lmax, + BLOCK_C=_tile_dim(channels), + ) + return out + + +def _launch_bd_to_local_bwd( + grad_out: Tensor, x: Tensor, src: Tensor, wigner: Tensor, lmax: int +) -> tuple[Tensor, Tensor]: + n_edge = int(src.shape[0]) + channels = int(x.shape[2]) + grad_x = torch.zeros_like(x) + grad_wigner = torch.zeros_like(wigner) + if n_edge == 0: + return grad_x, grad_wigner + _bd_to_local_bwd_kernel[(n_edge,)]( + grad_out, + x, + src, + wigner, + grad_x, + grad_wigner, + n_edge, + channels, + grad_out.stride(0), + grad_out.stride(1), + grad_out.stride(2), + x.stride(0), + x.stride(1), + x.stride(2), + wigner.stride(0), + wigner.stride(1), + wigner.stride(2), + grad_x.stride(0), + grad_x.stride(1), + grad_x.stride(2), + grad_wigner.stride(0), + grad_wigner.stride(1), + grad_wigner.stride(2), + LMAX=lmax, + BLOCK_C=_tile_dim(channels), + ) + return grad_x, grad_wigner + + +def _launch_bd_back_fwd(x_local: Tensor, wigner: Tensor, lmax: int) -> Tensor: + n_edge = int(x_local.shape[0]) + channels = int(x_local.shape[2]) + dim_full = (lmax + 1) ** 2 + out = torch.empty( + (n_edge, dim_full, channels), dtype=x_local.dtype, device=x_local.device + ) + if n_edge == 0: + return out + _bd_back_fwd_kernel[(n_edge,)]( + x_local, + wigner, + out, + n_edge, + channels, + x_local.stride(0), + x_local.stride(1), + x_local.stride(2), + wigner.stride(0), + wigner.stride(1), + wigner.stride(2), + out.stride(0), + out.stride(1), + out.stride(2), + LMAX=lmax, + BLOCK_C=_tile_dim(channels), + ) + return out + + +def _launch_bd_back_bwd( + grad_out: Tensor, x_local: Tensor, wigner: Tensor, lmax: int +) -> tuple[Tensor, Tensor]: + n_edge = int(x_local.shape[0]) + channels = int(x_local.shape[2]) + grad_x_local = torch.empty_like(x_local) + grad_wigner = torch.zeros_like(wigner) + if n_edge == 0: + return grad_x_local, grad_wigner + _bd_back_bwd_kernel[(n_edge,)]( + grad_out, + x_local, + wigner, + grad_x_local, + grad_wigner, + n_edge, + channels, + grad_out.stride(0), + grad_out.stride(1), + grad_out.stride(2), + x_local.stride(0), + x_local.stride(1), + x_local.stride(2), + wigner.stride(0), + wigner.stride(1), + wigner.stride(2), + grad_x_local.stride(0), + grad_x_local.stride(1), + grad_x_local.stride(2), + grad_wigner.stride(0), + grad_wigner.stride(1), + grad_wigner.stride(2), + LMAX=lmax, + BLOCK_C=_tile_dim(channels), + ) + return grad_x_local, grad_wigner + + +# ====================================================================== +# Dispatch helpers (triton on CUDA float, eager otherwise) +# ====================================================================== +def _use_triton(tensor: Tensor) -> bool: + return ( + TRITON_ROTATION_AVAILABLE + and tensor.is_cuda + and tensor.dtype in (torch.float16, torch.bfloat16, torch.float32) + ) + + +def _rotate_to_local_impl( + x: Tensor, + src: Tensor, + wigner: Tensor, + coeff_index: Tensor, + dim_full: int, +) -> Tensor: + if not _use_triton(x): + return rotate_to_local_reference(x, src, wigner, coeff_index, dim_full) + return _launch_rotate_to_local_fwd( + x, src.contiguous(), wigner, coeff_index.contiguous(), int(dim_full) + ) + + +def _rotate_to_local_bwd_impl( + grad_out: Tensor, + x: Tensor, + src: Tensor, + wigner: Tensor, + coeff_index: Tensor, + dim_full: int, +) -> tuple[Tensor, Tensor]: + if not _use_triton(x): + return _rotate_to_local_bwd_eager( + grad_out, x, src, wigner, coeff_index, dim_full + ) + return _launch_rotate_to_local_bwd( + grad_out.contiguous(), + x, + src.contiguous(), + wigner, + coeff_index.contiguous(), + int(dim_full), + ) + + +def _rotate_back_impl( + x_local: Tensor, + wigner: Tensor, + coeff_index: Tensor, + dim_full: int, +) -> Tensor: + if not _use_triton(x_local): + return rotate_back_reference(x_local, wigner, coeff_index, dim_full) + return _launch_rotate_back_fwd( + x_local, wigner, coeff_index.contiguous(), int(dim_full) + ) + + +def _rotate_back_bwd_impl( + grad_out: Tensor, + x_local: Tensor, + wigner: Tensor, + coeff_index: Tensor, + dim_full: int, +) -> tuple[Tensor, Tensor]: + if not _use_triton(x_local): + return _rotate_back_bwd_eager(grad_out, x_local, wigner, coeff_index, dim_full) + return _launch_rotate_back_bwd( + grad_out.contiguous(), + x_local, + wigner, + coeff_index.contiguous(), + int(dim_full), + ) + + +# --- block-diagonal impls (mmax == 1; assume block-diagonal Wigner-D) --- +def _block_rotate_to_local_impl( + x: Tensor, src: Tensor, wigner: Tensor, lmax: int +) -> Tensor: + if not _use_triton(x): + coeff = build_m_major_index(int(lmax), 1, device=x.device) + return rotate_to_local_reference(x, src, wigner, coeff, (int(lmax) + 1) ** 2) + return _launch_bd_to_local_fwd(x, src.contiguous(), wigner, int(lmax)) + + +def _block_rotate_to_local_bwd_impl( + grad_out: Tensor, x: Tensor, src: Tensor, wigner: Tensor, lmax: int +) -> tuple[Tensor, Tensor]: + if not _use_triton(x): + coeff = build_m_major_index(int(lmax), 1, device=x.device) + return _rotate_to_local_bwd_eager( + grad_out, x, src, wigner, coeff, (int(lmax) + 1) ** 2 + ) + return _launch_bd_to_local_bwd( + grad_out.contiguous(), x, src.contiguous(), wigner, int(lmax) + ) + + +def _block_rotate_back_impl(x_local: Tensor, wigner: Tensor, lmax: int) -> Tensor: + if not _use_triton(x_local): + coeff = build_m_major_index(int(lmax), 1, device=x_local.device) + return rotate_back_reference(x_local, wigner, coeff, (int(lmax) + 1) ** 2) + return _launch_bd_back_fwd(x_local, wigner, int(lmax)) + + +def _block_rotate_back_bwd_impl( + grad_out: Tensor, x_local: Tensor, wigner: Tensor, lmax: int +) -> tuple[Tensor, Tensor]: + if not _use_triton(x_local): + coeff = build_m_major_index(int(lmax), 1, device=x_local.device) + return _rotate_back_bwd_eager( + grad_out, x_local, wigner, coeff, (int(lmax) + 1) ** 2 + ) + return _launch_bd_back_bwd(grad_out.contiguous(), x_local, wigner, int(lmax)) + + +# ====================================================================== +# Modern functional custom ops + fake + autograd registration +# ====================================================================== +# Forward and backward are both *functional* custom ops (mutates_args=()), so +# functionalization keeps the full gradient path -- including grad w.r.t. +# ``wigner`` -- intact under ``torch.compile``. + +_rotate_to_local_op = torch.library.custom_op( + "sezm_accel::rotate_to_local", mutates_args=() +)(_rotate_to_local_impl) + +_rotate_to_local_bwd_op = torch.library.custom_op( + "sezm_accel::rotate_to_local_bwd", mutates_args=() +)(_rotate_to_local_bwd_impl) + +_rotate_back_op = torch.library.custom_op("sezm_accel::rotate_back", mutates_args=())( + _rotate_back_impl +) + +_rotate_back_bwd_op = torch.library.custom_op( + "sezm_accel::rotate_back_bwd", mutates_args=() +)(_rotate_back_bwd_impl) + + +@_rotate_to_local_op.register_fake +def _(x, src, wigner, coeff_index, dim_full): + return x.new_empty((src.shape[0], coeff_index.shape[0], x.shape[2])) + + +@_rotate_to_local_bwd_op.register_fake +def _(grad_out, x, src, wigner, coeff_index, dim_full): + return torch.empty_like(x), torch.empty_like(wigner) + + +@_rotate_back_op.register_fake +def _(x_local, wigner, coeff_index, dim_full): + return x_local.new_empty((x_local.shape[0], dim_full, x_local.shape[2])) + + +@_rotate_back_bwd_op.register_fake +def _(grad_out, x_local, wigner, coeff_index, dim_full): + return torch.empty_like(x_local), torch.empty_like(wigner) + + +def _rotate_to_local_setup_context(ctx, inputs, output): + x, src, wigner, coeff_index, dim_full = inputs + ctx.save_for_backward(x, src, wigner, coeff_index) + ctx.dim_full = dim_full + + +def _rotate_to_local_backward(ctx, grad_out): + x, src, wigner, coeff_index = ctx.saved_tensors + grad_x, grad_wigner = _rotate_to_local_bwd_op( + grad_out, x, src, wigner, coeff_index, ctx.dim_full + ) + return grad_x, None, grad_wigner, None, None + + +def _rotate_back_setup_context(ctx, inputs, output): + x_local, wigner, coeff_index, dim_full = inputs + ctx.save_for_backward(x_local, wigner, coeff_index) + ctx.dim_full = dim_full + + +def _rotate_back_backward(ctx, grad_out): + x_local, wigner, coeff_index = ctx.saved_tensors + grad_x_local, grad_wigner = _rotate_back_bwd_op( + grad_out, x_local, wigner, coeff_index, ctx.dim_full + ) + return grad_x_local, grad_wigner, None, None + + +_rotate_to_local_op.register_autograd( + _rotate_to_local_backward, setup_context=_rotate_to_local_setup_context +) +_rotate_back_op.register_autograd( + _rotate_back_backward, setup_context=_rotate_back_setup_context +) + + +# --- block-diagonal custom ops (carry only ``lmax``; no coeff_index tensor) --- +_block_to_local_op = torch.library.custom_op( + "sezm_accel::rotate_to_local_block", mutates_args=() +)(_block_rotate_to_local_impl) + +_block_to_local_bwd_op = torch.library.custom_op( + "sezm_accel::rotate_to_local_block_bwd", mutates_args=() +)(_block_rotate_to_local_bwd_impl) + +_block_back_op = torch.library.custom_op( + "sezm_accel::rotate_back_block", mutates_args=() +)(_block_rotate_back_impl) + +_block_back_bwd_op = torch.library.custom_op( + "sezm_accel::rotate_back_block_bwd", mutates_args=() +)(_block_rotate_back_bwd_impl) + + +@_block_to_local_op.register_fake +def _(x, src, wigner, lmax): + return x.new_empty((src.shape[0], 3 * int(lmax) + 1, x.shape[2])) + + +@_block_to_local_bwd_op.register_fake +def _(grad_out, x, src, wigner, lmax): + return torch.empty_like(x), torch.empty_like(wigner) + + +@_block_back_op.register_fake +def _(x_local, wigner, lmax): + return x_local.new_empty((x_local.shape[0], (int(lmax) + 1) ** 2, x_local.shape[2])) + + +@_block_back_bwd_op.register_fake +def _(grad_out, x_local, wigner, lmax): + return torch.empty_like(x_local), torch.empty_like(wigner) + + +def _block_to_local_setup_context(ctx, inputs, output): + x, src, wigner, lmax = inputs + ctx.save_for_backward(x, src, wigner) + ctx.lmax = lmax + + +def _block_to_local_backward(ctx, grad_out): + x, src, wigner = ctx.saved_tensors + grad_x, grad_wigner = _block_to_local_bwd_op(grad_out, x, src, wigner, ctx.lmax) + return grad_x, None, grad_wigner, None + + +def _block_back_setup_context(ctx, inputs, output): + x_local, wigner, lmax = inputs + ctx.save_for_backward(x_local, wigner) + ctx.lmax = lmax + + +def _block_back_backward(ctx, grad_out): + x_local, wigner = ctx.saved_tensors + grad_x_local, grad_wigner = _block_back_bwd_op(grad_out, x_local, wigner, ctx.lmax) + return grad_x_local, grad_wigner, None + + +_block_to_local_op.register_autograd( + _block_to_local_backward, setup_context=_block_to_local_setup_context +) +_block_back_op.register_autograd( + _block_back_backward, setup_context=_block_back_setup_context +) + + +# ====================================================================== +# Public API +# ====================================================================== +def rotate_to_local( + x: Tensor, + src: Tensor, + wigner: Tensor, + coeff_index: Tensor, + dim_full: int, +) -> Tensor: + """Fused ``global -> edge-local reduced`` rotation ``bmm(D_to_m, x[src])``. + + Parameters + ---------- + x + Node features, shape ``(N, D, C)``. + src + Source-node index per edge, shape ``(E,)`` int64. + wigner + Per-edge packed Wigner-D matrices, shape ``(E, Dw, Dw)`` with + ``Dw >= dim_full``. + coeff_index + m-major reduced-layout row indices, shape ``(Dm,)`` int64. + dim_full + Full packed SO(3) dimension ``D = (lmax+1)**2``. + + Returns + ------- + Tensor + Rotated reduced-layout edge features, shape ``(E, Dm, C)``. + + Notes + ----- + When ``(coeff_index, dim_full)`` is the m-major ``mmax=1`` layout, this + auto-selects the block-diagonal kernel (which assumes a block-diagonal + Wigner-D, as produced by the model); otherwise it uses the dense kernel. + """ + lmax = _block_layout_lmax(coeff_index, dim_full) + if lmax >= 0: + return _block_to_local_op(x, src, wigner, lmax) + return _rotate_to_local_op(x, src, wigner, coeff_index, int(dim_full)) + + +def rotate_back( + x_local: Tensor, + wigner: Tensor, + coeff_index: Tensor, + dim_full: int, +) -> Tensor: + """Fused ``edge-local reduced -> global`` rotation ``bmm(Dt_from_m, x_local)``. + + Parameters + ---------- + x_local + Reduced-layout edge features, shape ``(E, Dm, C)``. + wigner + Per-edge packed Wigner-D matrices, shape ``(E, Dw, Dw)`` with + ``Dw >= dim_full``. + coeff_index + m-major reduced-layout column indices, shape ``(Dm,)`` int64. + dim_full + Full packed SO(3) dimension ``D = (lmax+1)**2``. + + Returns + ------- + Tensor + Lifted global-layout edge features, shape ``(E, D, C)``. + + Notes + ----- + Auto-selects the block-diagonal kernel for the m-major ``mmax=1`` layout, + else the dense kernel (see ``rotate_to_local``). + """ + lmax = _block_layout_lmax(coeff_index, dim_full) + if lmax >= 0: + return _block_back_op(x_local, wigner, lmax) + return _rotate_back_op(x_local, wigner, coeff_index, int(dim_full)) + + +# --- Explicit entry points (benchmarking / forcing a path) --- +def rotate_to_local_dense( + x: Tensor, src: Tensor, wigner: Tensor, coeff_index: Tensor, dim_full: int +) -> Tensor: + """Force the dense (general-layout) ``rotate_to_local`` kernel.""" + return _rotate_to_local_op(x, src, wigner, coeff_index, int(dim_full)) + + +def rotate_back_dense( + x_local: Tensor, wigner: Tensor, coeff_index: Tensor, dim_full: int +) -> Tensor: + """Force the dense (general-layout) ``rotate_back`` kernel.""" + return _rotate_back_op(x_local, wigner, coeff_index, int(dim_full)) + + +def rotate_to_local_block( + x: Tensor, src: Tensor, wigner: Tensor, coeff_index: Tensor, dim_full: int +) -> Tensor: + """Force the block-diagonal ``rotate_to_local`` kernel (requires mmax=1 layout).""" + lmax = _block_layout_lmax(coeff_index, dim_full) + if lmax < 0: + raise ValueError( + "rotate_to_local_block requires the m-major mmax=1 coefficient layout." + ) + return _block_to_local_op(x, src, wigner, lmax) + + +def rotate_back_block( + x_local: Tensor, wigner: Tensor, coeff_index: Tensor, dim_full: int +) -> Tensor: + """Force the block-diagonal ``rotate_back`` kernel (requires mmax=1 layout).""" + lmax = _block_layout_lmax(coeff_index, dim_full) + if lmax < 0: + raise ValueError( + "rotate_back_block requires the m-major mmax=1 coefficient layout." + ) + return _block_back_op(x_local, wigner, lmax) diff --git a/deepmd/pt/model/model/sezm_model.py b/deepmd/pt/model/model/sezm_model.py index 44de7260a0..b1b98c6f37 100644 --- a/deepmd/pt/model/model/sezm_model.py +++ b/deepmd/pt/model/model/sezm_model.py @@ -601,7 +601,6 @@ def _sezm_structure_key(model: SeZMModel) -> tuple[Any, ...]: fitting_key = _module_shared_key(fitting) descriptor_state = ( _int_pair_tuple(descriptor.exclude_types), - bool(descriptor.use_triton), bool(descriptor.use_env_seed), bool(descriptor.use_gie), bool(descriptor.random_gamma), diff --git a/doc/model/dpa4.md b/doc/model/dpa4.md index f6a6638ef2..63d0e4538c 100644 --- a/doc/model/dpa4.md +++ b/doc/model/dpa4.md @@ -464,6 +464,21 @@ leave energy and force MAE nearly unchanged while making the potential energy surface less smooth. For less smoothness-sensitive evaluation or screening workloads, `DP_TF32_INFER=1` or `2` may be useful for improving throughput. +`DP_TRITON_INFER` enables fused block-diagonal Triton kernels for the SO(2) +Wigner-D rotation. It applies to evaluation and inference on CUDA in eval mode +only and is disabled by default: + +```bash +export DP_TRITON_INFER=1 +``` + +The kernels operate on the block-diagonal (by degree `l`) structure of the +Wigner-D matrix and are numerically equivalent to the default dense rotation up +to floating-point rounding. They retain full float32 accumulation regardless of +`DP_TF32_INFER` and are therefore appropriate for smoothness-sensitive +workflows. They are compatible with the compile path (`DP_COMPILE_INFER=1`) and +reduce both latency and peak memory. + ### Hardware selection DPA4/SeZM is designed for fp32 training and inference. Hardware selection diff --git a/source/tests/pt/model/test_descriptor_sezm_triton.py b/source/tests/pt/model/test_descriptor_sezm_triton.py index a0e3d44483..e09794e32f 100644 --- a/source/tests/pt/model/test_descriptor_sezm_triton.py +++ b/source/tests/pt/model/test_descriptor_sezm_triton.py @@ -1,960 +1,181 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +"""Unit tests for the block-diagonal Triton SO(2)/Wigner rotation kernels +(opt-in via ``DP_TRITON_INFER``). + +Two properties are checked against the eager PyTorch reference: + +1. Numerical correctness of ``rotate_to_local`` / ``rotate_back`` (forward and + backward) across ``lmax`` 2-5 with ``mmax == 1`` -- the only layout the block + kernels accept. The Wigner-D is block-diagonal by ``l``, so the kernel touches + only the structural non-zeros; the gradient w.r.t. the Wigner therefore matches + the reference on the block entries (the off-block reference gradient is + structurally discarded by the model, which builds the Wigner with zero + off-block entries). +2. ``torch.compile`` composability: gradients through the functional + ``custom_op`` must match the eager reference when the op is traced under + ``make_fx`` -- the autograd path that compiled inference uses to obtain + forces. +""" + import unittest import torch -from deepmd.pt.model.descriptor.sezm_nn import ( - C3CutoffEnvelope, - InnerClamp, - RadialBasis, +from deepmd.pt.model.descriptor.sezm_nn.indexing import ( build_m_major_index, - project_D_to_m, - project_Dt_from_m, -) -from deepmd.pt.model.descriptor.sezm_nn.triton import ( - SEZM_TRITON_AVAILABLE, - TritonRotationMode, - edge_geometry_rbf_triton, - resolve_triton_rotation_mode, - rotate_back_triton, - rotate_to_local_triton, -) - -TRITON_CUDA_AVAILABLE = SEZM_TRITON_AVAILABLE and torch.cuda.is_available() - - -class TestSeZMTritonDispatch(unittest.TestCase): - """Validate the SeZM Triton dispatch policy.""" - - def test_resolve_rotation_mode_covers_small_generic_and_fallback(self) -> None: - """Dispatch policy should cover small kernels, generic kernels, and fallback.""" - self.assertEqual( - resolve_triton_rotation_mode(dim_full=1, reduced_dim=1), - TritonRotationMode.SMALL_LE1, - ) - self.assertEqual( - resolve_triton_rotation_mode(dim_full=4, reduced_dim=4), - TritonRotationMode.SMALL_LE1, - ) - self.assertEqual( - resolve_triton_rotation_mode(dim_full=9, reduced_dim=7), - TritonRotationMode.SMALL_L2, - ) - self.assertEqual( - resolve_triton_rotation_mode(dim_full=16, reduced_dim=10), - TritonRotationMode.SMALL_L3, - ) - self.assertEqual( - resolve_triton_rotation_mode(dim_full=25, reduced_dim=15), - TritonRotationMode.EAGER_REFERENCE, - ) - self.assertEqual( - resolve_triton_rotation_mode(dim_full=25, reduced_dim=16), - TritonRotationMode.GENERIC_TILED, - ) - - -@unittest.skipUnless( - TRITON_CUDA_AVAILABLE, - "SeZM Triton rotation tests require CUDA and Triton.", + get_so3_dim_of_lmax, ) -class TestSeZMTritonEdgeGeometryRBF(unittest.TestCase): - """Validate the Triton edge geometry/RBF chain against eager reference.""" - - def _eager_reference( - self, - *, - coord_flat: torch.Tensor, - center_idx: torch.Tensor, - neighbor_idx: torch.Tensor, - edge_envelope: C3CutoffEnvelope, - radial_basis: RadialBasis, - inner_clamp: InnerClamp | None, - ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - """Compute the eager reference geometry/RBF chain.""" - center_pos = coord_flat.index_select(0, center_idx) - neighbor_pos = coord_flat.index_select(0, neighbor_idx) - edge_vec = neighbor_pos - center_pos - edge_len = torch.sqrt( - torch.sum(edge_vec * edge_vec, dim=-1, keepdim=True) + 1.0e-14 - ) - if inner_clamp is not None: - clamped = inner_clamp(edge_len) - edge_vec = edge_vec * (clamped / edge_len) - edge_len = clamped - edge_env = edge_envelope(edge_len) - edge_rbf = radial_basis(edge_len) - return edge_vec, edge_len, edge_env, edge_rbf - - def test_edge_geometry_rbf_matches_reference_forward_backward(self) -> None: - """Compare fused geometry/RBF chain with eager gather/clamp/envelope/rbf.""" - device = torch.device("cuda") - dtype = torch.float32 - coord_ref = torch.randn( - 12, - 3, - device=device, - dtype=dtype, - requires_grad=True, - ) - coord_triton = coord_ref.detach().clone().requires_grad_(True) - center_idx = torch.randint(0, 12, (9,), device=device, dtype=torch.long) - neighbor_idx = torch.randint(0, 12, (9,), device=device, dtype=torch.long) - edge_envelope = C3CutoffEnvelope(rcut=6.0, exponent=5).to(device) - radial_ref = RadialBasis(rcut=6.0, n_radial=6, dtype=dtype, exponent=7).to( - device - ) - radial_triton = RadialBasis(rcut=6.0, n_radial=6, dtype=dtype, exponent=7).to( - device - ) - radial_triton.load_state_dict(radial_ref.state_dict()) - - out_ref = self._eager_reference( - coord_flat=coord_ref, - center_idx=center_idx, - neighbor_idx=neighbor_idx, - edge_envelope=edge_envelope, - radial_basis=radial_ref, - inner_clamp=None, - ) - out_triton = edge_geometry_rbf_triton( - coord_flat=coord_triton, - center_coord_index=center_idx, - neighbor_coord_index=neighbor_idx, - edge_envelope=edge_envelope, - radial_basis=radial_triton, - eps=1.0e-7, - inner_clamp=None, - ) - for ref, tri in zip(out_ref, out_triton, strict=True): - torch.testing.assert_close(tri, ref, atol=1.0e-5, rtol=1.0e-5) - - grad_out = tuple(torch.randn_like(ref) for ref in out_ref) - grad_coord_ref, grad_freq_ref = torch.autograd.grad( - out_ref, - (coord_ref, radial_ref.adam_freqs), - grad_outputs=grad_out, - ) - grad_coord_triton, grad_freq_triton = torch.autograd.grad( - out_triton, - (coord_triton, radial_triton.adam_freqs), - grad_outputs=grad_out, - ) - torch.testing.assert_close( - grad_coord_triton, - grad_coord_ref, - atol=2.0e-5, - rtol=2.0e-5, - ) - torch.testing.assert_close( - grad_freq_triton, - grad_freq_ref, - atol=2.0e-5, - rtol=2.0e-5, - ) - - def test_edge_geometry_rbf_matches_reference_with_inner_clamp(self) -> None: - """Compare the clamped Triton path with eager reference.""" - device = torch.device("cuda") - dtype = torch.float32 - coord_ref = torch.randn( - 10, - 3, - device=device, - dtype=dtype, - requires_grad=True, - ) - coord_triton = coord_ref.detach().clone().requires_grad_(True) - center_idx = torch.randint(0, 10, (7,), device=device, dtype=torch.long) - neighbor_idx = torch.randint(0, 10, (7,), device=device, dtype=torch.long) - edge_envelope = C3CutoffEnvelope(rcut=6.0, exponent=5).to(device) - radial_ref = RadialBasis(rcut=6.0, n_radial=4, dtype=dtype, exponent=7).to( - device - ) - radial_triton = RadialBasis(rcut=6.0, n_radial=4, dtype=dtype, exponent=7).to( - device - ) - radial_triton.load_state_dict(radial_ref.state_dict()) - inner_clamp = InnerClamp(0.9, 1.3).to(device) - - out_ref = self._eager_reference( - coord_flat=coord_ref, - center_idx=center_idx, - neighbor_idx=neighbor_idx, - edge_envelope=edge_envelope, - radial_basis=radial_ref, - inner_clamp=inner_clamp, - ) - out_triton = edge_geometry_rbf_triton( - coord_flat=coord_triton, - center_coord_index=center_idx, - neighbor_coord_index=neighbor_idx, - edge_envelope=edge_envelope, - radial_basis=radial_triton, - eps=1.0e-7, - inner_clamp=inner_clamp, - ) - for ref, tri in zip(out_ref, out_triton, strict=True): - torch.testing.assert_close(tri, ref, atol=2.0e-5, rtol=2.0e-5) - - loss_ref = sum(x.square().sum() for x in out_ref) - loss_triton = sum(x.square().sum() for x in out_triton) - grad_coord_ref, grad_freq_ref = torch.autograd.grad( - loss_ref, - (coord_ref, radial_ref.adam_freqs), - ) - grad_coord_triton, grad_freq_triton = torch.autograd.grad( - loss_triton, - (coord_triton, radial_triton.adam_freqs), - ) - torch.testing.assert_close( - grad_coord_triton, - grad_coord_ref, - atol=3.0e-5, - rtol=3.0e-5, - ) - torch.testing.assert_close( - grad_freq_triton, - grad_freq_ref, - atol=3.0e-5, - rtol=3.0e-5, - ) - - -@unittest.skipUnless( - TRITON_CUDA_AVAILABLE, - "SeZM Triton rotation tests require CUDA and Triton.", +from deepmd.pt.model.descriptor.sezm_nn.triton.so2_rotation import ( + TRITON_ROTATION_AVAILABLE, + rotate_back, + rotate_back_reference, + rotate_to_local, + rotate_to_local_reference, ) -class TestSeZMTritonSO2(unittest.TestCase): - """Validate Triton SO(2) rotation kernels against the eager reference path.""" - def _require_cuda_bfloat16(self) -> None: - """Skip the mixed-precision Triton tests when CUDA bf16 is unavailable.""" - if not torch.cuda.is_bf16_supported(): - self.skipTest("CUDA bfloat16 is required for mixed-precision Triton tests.") +_CUDA = torch.cuda.is_available() - def test_rotate_to_local_matches_reference_forward_backward(self) -> None: - """Compare fused Triton rotate-to-local with projected eager matmul.""" - device = torch.device("cuda") - dtype = torch.float32 - n_node = 7 - n_edge = 11 - channels = 8 - for lmax, mmax in ((2, 1), (3, 1)): - dim_full = (lmax + 1) ** 2 - coeff_index = build_m_major_index(lmax, mmax, device=device) - src = torch.randint(0, n_node, (n_edge,), device=device, dtype=torch.long) - x_ref = torch.randn( - n_node, - dim_full, - channels, - device=device, - dtype=dtype, - requires_grad=True, - ) - wigner_ref = torch.randn( - n_edge, - dim_full, - dim_full, - device=device, - dtype=dtype, - requires_grad=True, - ) - x_triton = x_ref.detach().clone().requires_grad_(True) - wigner_triton = wigner_ref.detach().clone().requires_grad_(True) - out_ref = torch.bmm( - project_D_to_m( - D_full=wigner_ref, - coeff_index_m=coeff_index, - ebed_dim_full=dim_full, - cache=None, - key_lmax=lmax, - key_mmax=mmax, - ), - x_ref.index_select(0, src), - ) - out_triton = rotate_to_local_triton( - x=x_triton, - src=src, - wigner=wigner_triton, - coeff_index=coeff_index, - dim_full=dim_full, - ) - torch.testing.assert_close(out_triton, out_ref, atol=1.0e-5, rtol=1.0e-5) - - grad_out = torch.randn_like(out_ref) - grad_x_ref, grad_wigner_ref = torch.autograd.grad( - out_ref, - (x_ref, wigner_ref), - grad_outputs=grad_out, - ) - grad_x_triton, grad_wigner_triton = torch.autograd.grad( - out_triton, - (x_triton, wigner_triton), - grad_outputs=grad_out, - ) - torch.testing.assert_close( - grad_x_triton, - grad_x_ref, - atol=1.0e-5, - rtol=1.0e-5, - ) - torch.testing.assert_close( - grad_wigner_triton, - grad_wigner_ref, - atol=1.0e-5, - rtol=1.0e-5, - ) - - def test_rotate_back_matches_reference_forward_backward(self) -> None: - """Compare fused Triton rotate-back with projected eager matmul.""" - device = torch.device("cuda") - dtype = torch.float32 - n_edge = 11 - channels = 8 - for lmax, mmax in ((2, 1), (3, 1)): - dim_full = (lmax + 1) ** 2 - coeff_index = build_m_major_index(lmax, mmax, device=device) - reduced_dim = int(coeff_index.numel()) - x_local_ref = torch.randn( - n_edge, - reduced_dim, - channels, - device=device, - dtype=dtype, - requires_grad=True, - ) - wigner_ref = torch.randn( - n_edge, - dim_full, - dim_full, - device=device, - dtype=dtype, - requires_grad=True, - ) - x_local_triton = x_local_ref.detach().clone().requires_grad_(True) - wigner_triton = wigner_ref.detach().clone().requires_grad_(True) - - out_ref = torch.bmm( - project_Dt_from_m( - Dt_full=wigner_ref, - coeff_index_m=coeff_index, - ebed_dim_full=dim_full, - cache=None, - key_lmax=lmax, - key_mmax=mmax, - ), - x_local_ref, - ) - out_triton = rotate_back_triton( - x_local=x_local_triton, - wigner=wigner_triton, - coeff_index=coeff_index, - dim_full=dim_full, - ) - torch.testing.assert_close(out_triton, out_ref, atol=1.0e-5, rtol=1.0e-5) - - grad_out = torch.randn_like(out_ref) - grad_x_ref, grad_wigner_ref = torch.autograd.grad( - out_ref, - (x_local_ref, wigner_ref), - grad_outputs=grad_out, - ) - grad_x_triton, grad_wigner_triton = torch.autograd.grad( - out_triton, - (x_local_triton, wigner_triton), - grad_outputs=grad_out, - ) - torch.testing.assert_close( - grad_x_triton, - grad_x_ref, - atol=1.0e-5, - rtol=1.0e-5, - ) - torch.testing.assert_close( - grad_wigner_triton, - grad_wigner_ref, - atol=1.0e-5, - rtol=1.0e-5, - ) - - def test_rotate_to_local_matches_mixed_precision_reference(self) -> None: - """Compare Triton rotate-to-local with bf16 activations and fp32 Wigner.""" - self._require_cuda_bfloat16() - device = torch.device("cuda") - x_dtype = torch.bfloat16 - wigner_dtype = torch.float32 - n_node = 7 - n_edge = 11 - channels = 8 - for lmax, mmax in ((2, 1), (3, 1)): - dim_full = (lmax + 1) ** 2 - coeff_index = build_m_major_index(lmax, mmax, device=device) - src = torch.randint(0, n_node, (n_edge,), device=device, dtype=torch.long) - x_ref = torch.randn( - n_node, - dim_full, - channels, - device=device, - dtype=x_dtype, - requires_grad=True, - ) - wigner_ref = torch.randn( - n_edge, - dim_full, - dim_full, - device=device, - dtype=wigner_dtype, - requires_grad=True, - ) - x_triton = x_ref.detach().clone().requires_grad_(True) - wigner_triton = wigner_ref.detach().clone().requires_grad_(True) - - out_ref = torch.bmm( - project_D_to_m( - D_full=wigner_ref, - coeff_index_m=coeff_index, - ebed_dim_full=dim_full, - cache=None, - key_lmax=lmax, - key_mmax=mmax, - ).to(dtype=x_dtype), - x_ref.index_select(0, src), - ) - out_triton = rotate_to_local_triton( - x=x_triton, - src=src, - wigner=wigner_triton, - coeff_index=coeff_index, - dim_full=dim_full, - ) - torch.testing.assert_close(out_triton, out_ref, atol=3.0e-2, rtol=3.0e-2) - - grad_out = torch.randn_like(out_ref) - grad_x_ref, grad_wigner_ref = torch.autograd.grad( - out_ref, - (x_ref, wigner_ref), - grad_outputs=grad_out, - ) - grad_x_triton, grad_wigner_triton = torch.autograd.grad( - out_triton, - (x_triton, wigner_triton), - grad_outputs=grad_out, - ) - torch.testing.assert_close( - grad_x_triton, - grad_x_ref, - atol=3.0e-2, - rtol=3.0e-2, - ) - torch.testing.assert_close( - grad_wigner_triton, - grad_wigner_ref, - atol=3.0e-2, - rtol=3.0e-2, - ) - - def test_rotate_back_matches_mixed_precision_reference(self) -> None: - """Compare Triton rotate-back with bf16 activations and fp32 Wigner.""" - self._require_cuda_bfloat16() - device = torch.device("cuda") - x_dtype = torch.bfloat16 - wigner_dtype = torch.float32 - n_edge = 11 - channels = 8 - for lmax, mmax in ((2, 1), (3, 1)): - dim_full = (lmax + 1) ** 2 - coeff_index = build_m_major_index(lmax, mmax, device=device) - reduced_dim = int(coeff_index.numel()) - x_local_ref = torch.randn( - n_edge, - reduced_dim, - channels, - device=device, - dtype=x_dtype, - requires_grad=True, - ) - wigner_ref = torch.randn( - n_edge, - dim_full, - dim_full, - device=device, - dtype=wigner_dtype, - requires_grad=True, - ) - x_local_triton = x_local_ref.detach().clone().requires_grad_(True) - wigner_triton = wigner_ref.detach().clone().requires_grad_(True) - - out_ref = torch.bmm( - project_Dt_from_m( - Dt_full=wigner_ref, - coeff_index_m=coeff_index, - ebed_dim_full=dim_full, - cache=None, - key_lmax=lmax, - key_mmax=mmax, - ).to(dtype=x_dtype), - x_local_ref, - ) - out_triton = rotate_back_triton( - x_local=x_local_triton, - wigner=wigner_triton, - coeff_index=coeff_index, - dim_full=dim_full, - ) - torch.testing.assert_close(out_triton, out_ref, atol=3.0e-2, rtol=3.0e-2) - - grad_out = torch.randn_like(out_ref) - grad_x_ref, grad_wigner_ref = torch.autograd.grad( - out_ref, - (x_local_ref, wigner_ref), - grad_outputs=grad_out, - ) - grad_x_triton, grad_wigner_triton = torch.autograd.grad( - out_triton, - (x_local_triton, wigner_triton), - grad_outputs=grad_out, - ) - torch.testing.assert_close( - grad_x_triton, - grad_x_ref, - atol=3.0e-2, - rtol=3.0e-2, - ) - torch.testing.assert_close( - grad_wigner_triton, - grad_wigner_ref, - atol=3.0e-2, - rtol=3.0e-2, - ) - - def test_rotate_to_local_matches_bfloat16_autocast_semantics(self) -> None: - """Use the activation dtype selected by AMP for Triton rotate-to-local.""" - self._require_cuda_bfloat16() - device = torch.device("cuda") - act_dtype = torch.bfloat16 - wigner_dtype = torch.float32 - n_node = 7 - n_edge = 11 - dim_full = 16 - channels = 8 - coeff_index = build_m_major_index(3, 1, device=device) - src = torch.randint(0, n_node, (n_edge,), device=device, dtype=torch.long) - x_ref = torch.randn( - n_node, - dim_full, - channels, - device=device, - dtype=act_dtype, - requires_grad=True, - ) - wigner_ref = torch.randn( +def _block_diagonal_wigner(n_edge, lmax, device, dtype, generator): + """Random Wigner-D that is block-diagonal by ``l`` (block ``l`` occupies + rows/cols ``[l**2 : (l+1)**2]``); off-block entries are exactly zero. + """ + dim = get_so3_dim_of_lmax(lmax) + wigner = torch.zeros(n_edge, dim, dim, device=device, dtype=dtype) + for ll in range(lmax + 1): + start, end = ll * ll, (ll + 1) ** 2 + wigner[:, start:end, start:end] = torch.randn( n_edge, - dim_full, - dim_full, - device=device, - dtype=wigner_dtype, - requires_grad=True, - ) - x_triton = x_ref.detach().clone().requires_grad_(True) - wigner_triton = wigner_ref.detach().clone().requires_grad_(True) - - D_m_prime = project_D_to_m( - D_full=wigner_ref, - coeff_index_m=coeff_index, - ebed_dim_full=dim_full, - cache=None, - key_lmax=3, - key_mmax=1, - ).to(dtype=act_dtype) - out_ref = torch.bmm(D_m_prime, x_ref.index_select(0, src)) - out_triton = rotate_to_local_triton( - x=x_triton, - src=src, - wigner=wigner_triton, - coeff_index=coeff_index, - dim_full=dim_full, - ) - torch.testing.assert_close(out_triton, out_ref, atol=5.0e-2, rtol=5.0e-2) - - grad_out = torch.randn_like(out_ref) - grad_x_ref, grad_wigner_ref = torch.autograd.grad( - out_ref, - (x_ref, wigner_ref), - grad_outputs=grad_out, - ) - grad_x_triton, grad_wigner_triton = torch.autograd.grad( - out_triton, - (x_triton, wigner_triton), - grad_outputs=grad_out, - ) - torch.testing.assert_close( - grad_x_triton, - grad_x_ref, - atol=5.0e-2, - rtol=5.0e-2, - ) - torch.testing.assert_close( - grad_wigner_triton, - grad_wigner_ref, - atol=5.0e-2, - rtol=5.0e-2, - ) - - def test_rotate_back_matches_bfloat16_autocast_semantics(self) -> None: - """Use the activation dtype selected by AMP for Triton rotate-back.""" - self._require_cuda_bfloat16() - device = torch.device("cuda") - act_dtype = torch.bfloat16 - wigner_dtype = torch.float32 - n_edge = 11 - dim_full = 16 - channels = 8 - coeff_index = build_m_major_index(3, 1, device=device) - reduced_dim = int(coeff_index.numel()) - x_local_ref = torch.randn( - n_edge, - reduced_dim, - channels, - device=device, - dtype=act_dtype, - requires_grad=True, - ) - wigner_ref = torch.randn( - n_edge, - dim_full, - dim_full, - device=device, - dtype=wigner_dtype, - requires_grad=True, - ) - x_local_triton = x_local_ref.detach().clone().requires_grad_(True) - wigner_triton = wigner_ref.detach().clone().requires_grad_(True) - - Dt_from_m = project_Dt_from_m( - Dt_full=wigner_ref, - coeff_index_m=coeff_index, - ebed_dim_full=dim_full, - cache=None, - key_lmax=3, - key_mmax=1, - ).to(dtype=act_dtype) - out_ref = torch.bmm(Dt_from_m, x_local_ref) - out_triton = rotate_back_triton( - x_local=x_local_triton, - wigner=wigner_triton, - coeff_index=coeff_index, - dim_full=dim_full, - ) - torch.testing.assert_close(out_triton, out_ref, atol=5.0e-2, rtol=5.0e-2) - - grad_out = torch.randn_like(out_ref) - grad_x_ref, grad_wigner_ref = torch.autograd.grad( - out_ref, - (x_local_ref, wigner_ref), - grad_outputs=grad_out, - ) - grad_x_triton, grad_wigner_triton = torch.autograd.grad( - out_triton, - (x_local_triton, wigner_triton), - grad_outputs=grad_out, - ) - torch.testing.assert_close( - grad_x_triton, - grad_x_ref, - atol=5.0e-2, - rtol=5.0e-2, - ) - torch.testing.assert_close( - grad_wigner_triton, - grad_wigner_ref, - atol=5.0e-2, - rtol=5.0e-2, - ) - - def test_generic_small_k_falls_back_to_reference_forward_backward(self) -> None: - """Fallback to eager bmm when generic Triton tiles would have K < 16.""" - device = torch.device("cuda") - dtype = torch.float32 - lmax, mmax = 4, 0 - dim_full = (lmax + 1) ** 2 - n_node = 7 - n_edge = 11 - channels = 8 - coeff_index = build_m_major_index(lmax, mmax, device=device) - self.assertLess(int(coeff_index.numel()), 16) - - src = torch.randint(0, n_node, (n_edge,), device=device, dtype=torch.long) - x_ref = torch.randn( - n_node, - dim_full, - channels, + end - start, + end - start, device=device, dtype=dtype, - requires_grad=True, - ) - wigner_ref = torch.randn( - n_edge, - dim_full, - dim_full, - device=device, - dtype=dtype, - requires_grad=True, - ) - x_triton = x_ref.detach().clone().requires_grad_(True) - wigner_triton = wigner_ref.detach().clone().requires_grad_(True) - - out_ref = torch.bmm( - project_D_to_m( - D_full=wigner_ref, - coeff_index_m=coeff_index, - ebed_dim_full=dim_full, - cache=None, - key_lmax=lmax, - key_mmax=mmax, - ), - x_ref.index_select(0, src), - ) - out_triton = rotate_to_local_triton( - x=x_triton, - src=src, - wigner=wigner_triton, - coeff_index=coeff_index, - dim_full=dim_full, - ) - torch.testing.assert_close(out_triton, out_ref, atol=1.0e-5, rtol=1.0e-5) - - grad_out = torch.randn_like(out_ref) - grad_x_ref, grad_wigner_ref = torch.autograd.grad( - out_ref, - (x_ref, wigner_ref), - grad_outputs=grad_out, - ) - grad_x_triton, grad_wigner_triton = torch.autograd.grad( - out_triton, - (x_triton, wigner_triton), - grad_outputs=grad_out, - ) - torch.testing.assert_close( - grad_x_triton, - grad_x_ref, - atol=1.0e-5, - rtol=1.0e-5, - ) - torch.testing.assert_close( - grad_wigner_triton, - grad_wigner_ref, - atol=1.0e-5, - rtol=1.0e-5, - ) - - x_local_ref = torch.randn( - n_edge, - int(coeff_index.numel()), - channels, - device=device, - dtype=dtype, - requires_grad=True, - ) - wigner_back_ref = torch.randn( - n_edge, - dim_full, - dim_full, - device=device, - dtype=dtype, - requires_grad=True, - ) - x_local_triton = x_local_ref.detach().clone().requires_grad_(True) - wigner_back_triton = wigner_back_ref.detach().clone().requires_grad_(True) - - out_back_ref = torch.bmm( - project_Dt_from_m( - Dt_full=wigner_back_ref, - coeff_index_m=coeff_index, - ebed_dim_full=dim_full, - cache=None, - key_lmax=lmax, - key_mmax=mmax, - ), - x_local_ref, - ) - out_back_triton = rotate_back_triton( - x_local=x_local_triton, - wigner=wigner_back_triton, - coeff_index=coeff_index, - dim_full=dim_full, - ) - torch.testing.assert_close( - out_back_triton, - out_back_ref, - atol=1.0e-5, - rtol=1.0e-5, - ) - - grad_back = torch.randn_like(out_back_ref) - grad_x_local_ref, grad_wigner_back_ref = torch.autograd.grad( - out_back_ref, - (x_local_ref, wigner_back_ref), - grad_outputs=grad_back, - ) - grad_x_local_triton, grad_wigner_back_triton = torch.autograd.grad( - out_back_triton, - (x_local_triton, wigner_back_triton), - grad_outputs=grad_back, - ) - torch.testing.assert_close( - grad_x_local_triton, - grad_x_local_ref, - atol=1.0e-5, - rtol=1.0e-5, - ) - torch.testing.assert_close( - grad_wigner_back_triton, - grad_wigner_back_ref, - atol=1.0e-5, - rtol=1.0e-5, - ) - - def test_generic_large_k_matches_reference_forward_backward(self) -> None: - """Exercise the true generic Triton path when reduced_dim >= 16.""" - device = torch.device("cuda") - dtype = torch.float32 - n_node = 7 - n_edge = 11 - channels = 8 - for lmax, mmax in ((4, 2), (4, 4), (5, 2)): - dim_full = (lmax + 1) ** 2 - coeff_index = build_m_major_index(lmax, mmax, device=device) - self.assertGreaterEqual(int(coeff_index.numel()), 16) - - src = torch.randint(0, n_node, (n_edge,), device=device, dtype=torch.long) - x_ref = torch.randn( - n_node, - dim_full, - channels, - device=device, - dtype=dtype, - requires_grad=True, - ) - wigner_ref = torch.randn( - n_edge, - dim_full, - dim_full, - device=device, - dtype=dtype, - requires_grad=True, - ) - x_triton = x_ref.detach().clone().requires_grad_(True) - wigner_triton = wigner_ref.detach().clone().requires_grad_(True) - - out_ref = torch.bmm( - project_D_to_m( - D_full=wigner_ref, - coeff_index_m=coeff_index, - ebed_dim_full=dim_full, - cache=None, - key_lmax=lmax, - key_mmax=mmax, - ), - x_ref.index_select(0, src), - ) - out_triton = rotate_to_local_triton( - x=x_triton, - src=src, - wigner=wigner_triton, - coeff_index=coeff_index, - dim_full=dim_full, - ) - torch.testing.assert_close(out_triton, out_ref, atol=1.0e-5, rtol=1.0e-5) - - grad_out = torch.randn_like(out_ref) - grad_x_ref, grad_wigner_ref = torch.autograd.grad( - out_ref, - (x_ref, wigner_ref), - grad_outputs=grad_out, - ) - grad_x_triton, grad_wigner_triton = torch.autograd.grad( - out_triton, - (x_triton, wigner_triton), - grad_outputs=grad_out, - ) - torch.testing.assert_close( - grad_x_triton, - grad_x_ref, - atol=1.0e-5, - rtol=1.0e-5, - ) - torch.testing.assert_close( - grad_wigner_triton, - grad_wigner_ref, - atol=1.0e-5, - rtol=1.0e-5, - ) - - x_local_ref = torch.randn( - n_edge, - int(coeff_index.numel()), - channels, - device=device, - dtype=dtype, - requires_grad=True, - ) - wigner_back_ref = torch.randn( - n_edge, - dim_full, - dim_full, - device=device, - dtype=dtype, - requires_grad=True, - ) - x_local_triton = x_local_ref.detach().clone().requires_grad_(True) - wigner_back_triton = wigner_back_ref.detach().clone().requires_grad_(True) - - out_back_ref = torch.bmm( - project_Dt_from_m( - Dt_full=wigner_back_ref, - coeff_index_m=coeff_index, - ebed_dim_full=dim_full, - cache=None, - key_lmax=lmax, - key_mmax=mmax, - ), - x_local_ref, - ) - out_back_triton = rotate_back_triton( - x_local=x_local_triton, - wigner=wigner_back_triton, - coeff_index=coeff_index, - dim_full=dim_full, - ) - torch.testing.assert_close( - out_back_triton, - out_back_ref, - atol=1.0e-5, - rtol=1.0e-5, - ) - - grad_back = torch.randn_like(out_back_ref) - grad_x_local_ref, grad_wigner_back_ref = torch.autograd.grad( - out_back_ref, - (x_local_ref, wigner_back_ref), - grad_outputs=grad_back, - ) - grad_x_local_triton, grad_wigner_back_triton = torch.autograd.grad( - out_back_triton, - (x_local_triton, wigner_back_triton), - grad_outputs=grad_back, - ) - torch.testing.assert_close( - grad_x_local_triton, - grad_x_local_ref, - atol=1.0e-5, - rtol=1.0e-5, - ) - torch.testing.assert_close( - grad_wigner_back_triton, - grad_wigner_back_ref, - atol=1.0e-5, - rtol=1.0e-5, - ) + generator=generator, + ) + return wigner + + +def _block_mask(lmax, device): + dim = get_so3_dim_of_lmax(lmax) + mask = torch.zeros(dim, dim, dtype=torch.bool, device=device) + for ll in range(lmax + 1): + start, end = ll * ll, (ll + 1) ** 2 + mask[start:end, start:end] = True + return mask + + +@unittest.skipIf(not _CUDA, "CUDA is required for the Triton rotation kernels") +@unittest.skipIf(not TRITON_ROTATION_AVAILABLE, "Triton is not available") +class TestSeZMTritonRotation(unittest.TestCase): + def setUp(self): + self.device = torch.device("cuda") + self.dtype = torch.float32 + self.n_node, self.n_edge, self.channels = 64, 2000, 16 + self.tol = {"rtol": 2e-4, "atol": 2e-4} + + def _inputs(self, lmax, seed): + gen = torch.Generator(device=self.device).manual_seed(seed) + dim = get_so3_dim_of_lmax(lmax) + coeff_index = build_m_major_index(lmax, 1, device=self.device) + x = torch.randn( + self.n_node, + dim, + self.channels, + device=self.device, + dtype=self.dtype, + generator=gen, + ) + src = torch.randint( + 0, self.n_node, (self.n_edge,), device=self.device, generator=gen + ) + wigner = _block_diagonal_wigner(self.n_edge, lmax, self.device, self.dtype, gen) + return x, src, wigner, coeff_index, dim + + def test_rotate_to_local_matches_reference(self): + for lmax in (2, 3, 4, 5): + with self.subTest(lmax=lmax): + x0, src, w0, coeff_index, dim = self._inputs(lmax, seed=lmax) + mask = _block_mask(lmax, self.device) + + xa = x0.clone().requires_grad_(True) + wa = w0.clone().requires_grad_(True) + out = rotate_to_local(xa, src, wa, coeff_index, dim) + xr = x0.clone().requires_grad_(True) + wr = w0.clone().requires_grad_(True) + ref = rotate_to_local_reference(xr, src, wr, coeff_index, dim) + + torch.testing.assert_close(out, ref, **self.tol) + + grad_out = torch.randn_like(ref) + gxa, gwa = torch.autograd.grad( + out, [xa, wa], grad_out, retain_graph=True + ) + gxr, gwr = torch.autograd.grad(ref, [xr, wr], grad_out) + torch.testing.assert_close(gxa, gxr, **self.tol) + torch.testing.assert_close(gwa[:, mask], gwr[:, mask], **self.tol) + # The kernel never writes off-block Wigner gradient entries. + self.assertEqual(float(gwa[:, ~mask].abs().max()), 0.0) + + def test_rotate_back_matches_reference(self): + for lmax in (2, 3, 4, 5): + with self.subTest(lmax=lmax): + _, _, w0, coeff_index, dim = self._inputs(lmax, seed=lmax) + reduced = int(coeff_index.numel()) + gen = torch.Generator(device=self.device).manual_seed(100 + lmax) + xl0 = torch.randn( + self.n_edge, + reduced, + self.channels, + device=self.device, + dtype=self.dtype, + generator=gen, + ) + mask = _block_mask(lmax, self.device) + + xa = xl0.clone().requires_grad_(True) + wa = w0.clone().requires_grad_(True) + out = rotate_back(xa, wa, coeff_index, dim) + xr = xl0.clone().requires_grad_(True) + wr = w0.clone().requires_grad_(True) + ref = rotate_back_reference(xr, wr, coeff_index, dim) + + torch.testing.assert_close(out, ref, **self.tol) + + grad_out = torch.randn_like(ref) + gxa, gwa = torch.autograd.grad( + out, [xa, wa], grad_out, retain_graph=True + ) + gxr, gwr = torch.autograd.grad(ref, [xr, wr], grad_out) + torch.testing.assert_close(gxa, gxr, **self.tol) + torch.testing.assert_close(gwa[:, mask], gwr[:, mask], **self.tol) + + def test_torch_compile_composability(self): + """Gradients through the op match between eager and compiled tracing.""" + lmax = 3 + x0, src, w0, coeff_index, dim = self._inputs(lmax, seed=7) + weight = torch.randn_like( + rotate_to_local_reference(x0, src, w0, coeff_index, dim) + ) + mask = _block_mask(lmax, self.device) + + def scalar_output(x, wigner): + return (rotate_to_local(x, src, wigner, coeff_index, dim) * weight).sum() + + xe = x0.clone().requires_grad_(True) + we = w0.clone().requires_grad_(True) + gxe, gwe = torch.autograd.grad(scalar_output(xe, we), [xe, we]) + + compiled = torch.compile(scalar_output, dynamic=True) + xc = x0.clone().requires_grad_(True) + wc = w0.clone().requires_grad_(True) + gxc, gwc = torch.autograd.grad(compiled(xc, wc), [xc, wc]) + + torch.testing.assert_close(gxc, gxe, **self.tol) + # Also check the Wigner gradient (nonzero on the block entries) survives + # tracing, since it flows through the custom op's registered backward. + torch.testing.assert_close(gwc[:, mask], gwe[:, mask], **self.tol) + self.assertGreater(float(gwe[:, mask].abs().max()), 0.0) + + +if __name__ == "__main__": + unittest.main() diff --git a/source/tests/pt/model/test_sezm_model.py b/source/tests/pt/model/test_sezm_model.py index 15175a5adc..5abbf55173 100644 --- a/source/tests/pt/model/test_sezm_model.py +++ b/source/tests/pt/model/test_sezm_model.py @@ -586,7 +586,6 @@ def test_fixed_edge_geometry_matches_standard_cache(self) -> None: n_radial=descriptor.radial_basis.n_radial, random_gamma=False, wigner_calc=descriptor.wigner_calc, - use_geometry_rbf_triton=False, ) edge_index, edge_vec, edge_mask = model.build_edge_list_from_nlist( From c681e5a987ce22f14833992c63d16fff7ff4155a Mon Sep 17 00:00:00 2001 From: OutisLi Date: Sun, 7 Jun 2026 22:44:54 +0800 Subject: [PATCH 04/18] fixup --- backend/find_pytorch.py | 10 +- deepmd/pt/infer/deep_eval.py | 8 +- deepmd/pt/model/descriptor/sezm.py | 28 +- deepmd/pt/model/descriptor/sezm_nn/block.py | 3 +- .../descriptor/sezm_nn/cute/so2_rotation.py | 8 +- .../pt/model/descriptor/sezm_nn/edge_cache.py | 36 +- deepmd/pt/model/descriptor/sezm_nn/so2.py | 49 ++- .../descriptor/sezm_nn/triton/__init__.py | 18 +- .../descriptor/sezm_nn/triton/so2_rotation.py | 160 +++---- deepmd/pt/train/training.py | 39 +- .../test_descriptor_sezm_grid_projection.py | 2 +- .../pt/model/test_descriptor_sezm_triton.py | 400 ++++++++++++++---- source/tests/pt/model/test_nv_nlist.py | 91 ++-- 13 files changed, 564 insertions(+), 288 deletions(-) diff --git a/backend/find_pytorch.py b/backend/find_pytorch.py index 48361d86af..9961977609 100644 --- a/backend/find_pytorch.py +++ b/backend/find_pytorch.py @@ -142,9 +142,13 @@ def get_pt_requirement(pt_version: str = "") -> dict: # under the torch extra rather than the core deps (conda-forge has # vesin but not vesin-torch). "vesin[torch]", - # GPU O(N) cell-list neighbor list for large systems; the package - # requires Python >= 3.11 while deepmd-kit still supports 3.10. - "nvalchemi-toolkit-ops>=0.3.1; python_version >= '3.11'", + # GPU O(N) cell-list neighbor list for large systems. Restricted to + # Python >= 3.11 (the package requires it while deepmd-kit still + # supports 3.10) and to Linux: it is a CUDA package, and its + # dependency warp-lang ships no macosx_x86_64 wheel, which otherwise + # makes the macOS x86_64 wheel build's dependency resolution + # unsatisfiable. + "nvalchemi-toolkit-ops>=0.3.1; python_version >= '3.11' and platform_system == 'Linux'", *mpi_requirement, *cibw_requirement, ], diff --git a/deepmd/pt/infer/deep_eval.py b/deepmd/pt/infer/deep_eval.py index 640b5fa918..7d54c7ef01 100644 --- a/deepmd/pt/infer/deep_eval.py +++ b/deepmd/pt/infer/deep_eval.py @@ -325,6 +325,12 @@ def _setup_nlist_backend(self, nlist_backend: str) -> None: "vesin[torch]`) or use nlist_backend='native' (or 'auto')." ) builder = VesinNeighborList() + elif DEVICE.type != "cuda": + raise ValueError( + "nlist_backend='nv' requires CUDA inference tensors; " + f"current DEVICE is {DEVICE!s}. Use nlist_backend='native' " + "(or 'auto') for CPU inference." + ) elif not is_nv_available(): raise ImportError( "nlist_backend='nv' was requested but 'nvalchemi-toolkit-ops'" @@ -338,7 +344,7 @@ def _setup_nlist_backend(self, nlist_backend: str) -> None: # Pick the first available O(N) builder; nv is GPU-only. if is_vesin_torch_available(): builder = VesinNeighborList() - elif is_nv_available() and torch.cuda.is_available(): + elif is_nv_available() and DEVICE.type == "cuda": builder = NvNeighborList() self._nlist_builder = builder diff --git a/deepmd/pt/model/descriptor/sezm.py b/deepmd/pt/model/descriptor/sezm.py index b20e6666a0..ba3ef38e65 100644 --- a/deepmd/pt/model/descriptor/sezm.py +++ b/deepmd/pt/model/descriptor/sezm.py @@ -619,7 +619,7 @@ def __init__( self.kmax = int(kmax) if self.kmax < 0: raise ValueError("`kmax` must be non-negative") - if self.kmax > int(lmax): + if self.kmax > self.lmax: raise ValueError("`kmax` must be <= `lmax`") self.ebed_dims = [get_so3_dim_of_lmax(l) for l in self.l_schedule] self._init_node_l_schedules(extra_node_l) @@ -844,8 +844,14 @@ def __init__( blocks: list[SeZMInteractionBlock] = [] for block_idx, (l_b, node_l_b, m_b) in enumerate( - zip(self.l_schedule, self.node_l_schedule, self.m_schedule) + zip( + self.l_schedule, + self.node_l_schedule, + self.m_schedule, + strict=True, + ) ): + k_b = min(self.kmax, l_b) blocks.append( SeZMInteractionBlock( lmax=l_b, @@ -877,7 +883,7 @@ def __init__( message_node_so3=self.message_node_so3, ffn_s2_activation=self.ffn_s2_activation, ffn_so3_grid=self.ffn_so3_grid, - kmax=self.kmax, + kmax=k_b, so2_lebedev_quadrature=self.so2_lebedev_quadrature, ffn_lebedev_quadrature=self.ffn_lebedev_quadrature, n_atten_head=self.n_atten_head, @@ -1511,12 +1517,14 @@ def _build_gie_zonal_coupling( mp_row_index, mp_m0_col_index, ] - edge_len = safe_norm(edge_cache.edge_vec, self.eps) - edge_quat = build_edge_quaternion( - edge_cache.edge_vec, - edge_len=edge_len, - eps=self.eps, - ) + edge_quat = edge_cache.edge_quat + if edge_quat is None: + edge_len = safe_norm(edge_cache.edge_vec, self.eps) + edge_quat = build_edge_quaternion( + edge_cache.edge_vec, + edge_len=edge_len, + eps=self.eps, + ) extra_coupling = self.gie_zonal_wigner_calc.forward_zonal( edge_quat, lmin=self.lmax + 1, @@ -1676,7 +1684,7 @@ def _init_lm_schedules( raise ValueError("`m_schedule` must have the same length as `l_schedule`") if any(x < 0 for x in self.m_schedule): raise ValueError("`m_schedule` entries must be non-negative") - if any(m > l for m, l in zip(self.m_schedule, self.l_schedule)): + if any(m > l for m, l in zip(self.m_schedule, self.l_schedule, strict=True)): raise ValueError( "`m_schedule` entries must satisfy `m_schedule[i] <= l_schedule[i]`" ) diff --git a/deepmd/pt/model/descriptor/sezm_nn/block.py b/deepmd/pt/model/descriptor/sezm_nn/block.py index bc19d53cc6..433eb000bf 100644 --- a/deepmd/pt/model/descriptor/sezm_nn/block.py +++ b/deepmd/pt/model/descriptor/sezm_nn/block.py @@ -638,7 +638,8 @@ def _use_infer_activation_checkpoint(self, *tensors: torch.Tensor) -> bool: """ return ( not self.training - and os.environ.get("DP_ACT_INFER") == "1" + and os.environ.get("DP_ACT_INFER", "").strip().lower() + in {"1", "true", "yes", "on"} and os.environ.get("DP_COMPILE_INFER", "").strip().lower() not in {"1", "true", "yes", "on"} and torch.is_grad_enabled() diff --git a/deepmd/pt/model/descriptor/sezm_nn/cute/so2_rotation.py b/deepmd/pt/model/descriptor/sezm_nn/cute/so2_rotation.py index f7bf36c743..9af65aaaac 100644 --- a/deepmd/pt/model/descriptor/sezm_nn/cute/so2_rotation.py +++ b/deepmd/pt/model/descriptor/sezm_nn/cute/so2_rotation.py @@ -830,10 +830,12 @@ def _rb_backward(ctx: Any, grad_out: Tensor) -> tuple: # === Public API ============================================================== -def _cute_usable(*tensors: Tensor) -> bool: +def _cute_usable(channels: int, *tensors: Tensor) -> bool: """Return True when the CuTe fast path is available for these tensors.""" if not SEZM_CUTE_AVAILABLE: return False + if int(channels) < _TN or int(channels) % _TN != 0: + return False return all( t.is_cuda and t.dtype == torch.float32 for t in tensors if t.is_floating_point() ) @@ -872,7 +874,7 @@ def rotate_to_local_cute( Experimental path that is not used in production. See the module docstring for the benchmark conclusion and why the Triton kernels were chosen instead. """ - if _cute_usable(x, wigner) and src.numel() > 0: + if _cute_usable(x.shape[2], x, wigner) and src.numel() > 0: return torch.ops.sezm_cute.rotate_to_local( x, src, wigner, coeff_index, int(dim_full) ) @@ -909,7 +911,7 @@ def rotate_back_cute( Experimental path that is not used in production. See the module docstring for the benchmark conclusion and why the Triton kernels were chosen instead. """ - if _cute_usable(x_local, wigner) and x_local.shape[0] > 0: + if _cute_usable(x_local.shape[2], x_local, wigner) and x_local.shape[0] > 0: return torch.ops.sezm_cute.rotate_back( x_local, wigner, coeff_index, int(dim_full) ) diff --git a/deepmd/pt/model/descriptor/sezm_nn/edge_cache.py b/deepmd/pt/model/descriptor/sezm_nn/edge_cache.py index 2174dfd4b3..0545c82b76 100644 --- a/deepmd/pt/model/descriptor/sezm_nn/edge_cache.py +++ b/deepmd/pt/model/descriptor/sezm_nn/edge_cache.py @@ -71,6 +71,9 @@ class EdgeFeatureCache(NamedTuple): Used for efficient batched rotation. None if not available. Dt_full Transpose of D_full with shape (E, D, D). None if not available. + edge_quat + Per-edge global-to-local quaternion actually used to build ``D_full`` and + ``Dt_full`` with shape (E, 4). Includes the optional random local-Z roll. D_to_m_cache Lazy cache for projected D matrices keyed by a normalized ``"lmax:mmax"`` identifier. @@ -103,6 +106,7 @@ class EdgeFeatureCache(NamedTuple): D_to_m_cache: dict[str, torch.Tensor] | None = None Dt_from_m_cache: dict[str, torch.Tensor] | None = None edge_src_gate: torch.Tensor | None = None + edge_quat: torch.Tensor | None = None def compute_edge_src_gate( @@ -337,13 +341,13 @@ def build_edge_cache( # === Step 6. Edge quaternion -> Wigner-D blocks === with nvtx_range("wigner_d"): - D_full, Dt_full = _build_edge_wigner( + D_full, Dt_full, edge_quat = _build_edge_wigner( edge_vec=edge_vec, edge_len=edge_len, eps=eps, random_gamma=random_gamma, wigner_calc=wigner_calc, - ) # (E, D, D), (E, D, D) + ) # (E, D, D), (E, D, D), (E, 4) edge_type_feat = build_edge_type_feat(type_ebed, src, dst) # (E, C) @@ -357,6 +361,7 @@ def build_edge_cache( edge_env=edge_env, D_full=D_full, Dt_full=Dt_full, + edge_quat=edge_quat, deg_norm_floor=deg_norm_floor, ) @@ -460,13 +465,13 @@ def build_edge_cache_from_edges( # === Step 4. Edge quaternion -> Wigner-D blocks === with nvtx_range("wigner_d"): - D_full, Dt_full = _build_edge_wigner( + D_full, Dt_full, edge_quat = _build_edge_wigner( edge_vec=edge_vec, edge_len=edge_len, eps=eps, random_gamma=random_gamma, wigner_calc=wigner_calc, - ) # (E, D, D), (E, D, D) + ) # (E, D, D), (E, D, D), (E, 4) # === Step 5. Edge type features === edge_type_feat = build_edge_type_feat(type_ebed, src, dst) @@ -498,6 +503,7 @@ def build_edge_cache_from_edges( edge_env=edge_env, D_full=D_full, Dt_full=Dt_full, + edge_quat=edge_quat, deg_norm_floor=deg_norm_floor, edge_src_gate=edge_src_gate, ) @@ -510,7 +516,7 @@ def _build_edge_wigner( eps: float, random_gamma: bool, wigner_calc: WignerCalculatorFn, -) -> tuple[torch.Tensor, torch.Tensor]: +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Build packed Wigner-D blocks from edge vectors. @@ -530,8 +536,9 @@ def _build_edge_wigner( Returns ------- - tuple[torch.Tensor, torch.Tensor] - Packed Wigner-D matrices ``(D_full, Dt_full)`` with shape ``(E, D, D)``. + tuple[torch.Tensor, torch.Tensor, torch.Tensor] + Packed Wigner-D matrices ``(D_full, Dt_full)`` with shape ``(E, D, D)`` + and the quaternion used to build them with shape ``(E, 4)``. """ # === Step 1. Build edge-aligned quaternions === edge_quat = build_edge_quaternion( @@ -550,7 +557,8 @@ def _build_edge_wigner( edge_quat = quaternion_multiply(quaternion_z_rotation(gamma), edge_quat) # === Step 3. Convert quaternions to packed Wigner-D blocks === - return wigner_calc(edge_quat) + D_full, Dt_full = wigner_calc(edge_quat) + return D_full, Dt_full, edge_quat def _finalize_edge_cache( @@ -564,6 +572,7 @@ def _finalize_edge_cache( edge_env: torch.Tensor, D_full: torch.Tensor, Dt_full: torch.Tensor, + edge_quat: torch.Tensor, deg_norm_floor: float, edge_src_gate: torch.Tensor | None = None, ) -> EdgeFeatureCache: @@ -590,6 +599,9 @@ def _finalize_edge_cache( Packed Wigner-D matrices with shape (E, D, D). Dt_full Transposed packed Wigner-D matrices with shape (E, D, D). + edge_quat + Global-to-local quaternions used to build the Wigner-D matrices with + shape (E, 4). deg_norm_floor Floor added to the envelope-squared degree before the inverse-sqrt normalization. A tiny ``eps`` reproduces the legacy behavior; an @@ -627,6 +639,7 @@ def _finalize_edge_cache( D_to_m_cache={}, Dt_from_m_cache={}, edge_src_gate=edge_src_gate, + edge_quat=edge_quat, ) @@ -661,6 +674,7 @@ def _get_empty_edge_cache( """ empty_long = torch.empty(0, dtype=torch.long, device=device) empty_vec = torch.empty(0, 3, dtype=dtype, device=device) + empty_quat = torch.empty(0, 4, dtype=dtype, device=device) empty_rbf = torch.empty(0, n_radial, dtype=dtype, device=device) empty_type_feat = torch.empty(0, n_channel, dtype=dtype, device=device) deg = torch.zeros(n_nodes, dtype=dtype, device=device) @@ -679,6 +693,7 @@ def _get_empty_edge_cache( D_to_m_cache={}, Dt_from_m_cache={}, edge_src_gate=None, + edge_quat=empty_quat, ) @@ -835,15 +850,19 @@ def edge_cache_to_dtype( _D_full = cache.D_full _Dt_full = cache.Dt_full _edge_src_gate = cache.edge_src_gate + _edge_quat = cache.edge_quat D_full: torch.Tensor | None = None Dt_full: torch.Tensor | None = None edge_src_gate: torch.Tensor | None = None + edge_quat: torch.Tensor | None = None if _D_full is not None: D_full = _D_full.to(dtype=dtype) if _Dt_full is not None: Dt_full = _Dt_full.to(dtype=dtype) if _edge_src_gate is not None: edge_src_gate = _edge_src_gate.to(dtype=dtype) + if _edge_quat is not None: + edge_quat = _edge_quat.to(dtype=dtype) return EdgeFeatureCache( src=cache.src, @@ -859,4 +878,5 @@ def edge_cache_to_dtype( D_to_m_cache=None if cache.D_to_m_cache is None else {}, Dt_from_m_cache=None if cache.Dt_from_m_cache is None else {}, edge_src_gate=edge_src_gate, + edge_quat=edge_quat, ) diff --git a/deepmd/pt/model/descriptor/sezm_nn/so2.py b/deepmd/pt/model/descriptor/sezm_nn/so2.py index a36cb25809..84bd43c966 100644 --- a/deepmd/pt/model/descriptor/sezm_nn/so2.py +++ b/deepmd/pt/model/descriptor/sezm_nn/so2.py @@ -71,10 +71,6 @@ FocusLinear, SO3Linear, ) -from .triton import ( - rotate_back, - rotate_to_local, -) from .utils import ( ATTN_RES_MODES, get_promoted_dtype, @@ -958,6 +954,23 @@ def __init__( self.use_triton_infer = os.environ.get( "DP_TRITON_INFER", "0" ).strip().lower() in ("1", "true", "yes", "on") + # Triton rotation kernels: block for the mmax == 1 layout, dense otherwise. + self._rotate_to_local_fn = None + self._rotate_back_fn = None + if self.use_triton_infer: + from .triton import ( + rotate_back_block, + rotate_back_dense, + rotate_to_local_block, + rotate_to_local_dense, + ) + + if self.mmax == 1: + self._rotate_to_local_fn = rotate_to_local_block + self._rotate_back_fn = rotate_back_block + else: + self._rotate_to_local_fn = rotate_to_local_dense + self._rotate_back_fn = rotate_back_dense # === Step 1. Precompute coefficient indices for m-major reduced layout === coeff_index_m = build_m_major_index(self.lmax, self.mmax, device=self.device) @@ -1334,6 +1347,7 @@ def __init__( dtype=self.compute_dtype, layout="flat", grid_resolution_list=self.s2_grid_resolution, + coefficient_layout="m_major", grid_method=self.s2_grid_method, grid_branches=node_wise_branches, mlp_bias=self.mlp_bias, @@ -1443,26 +1457,15 @@ def forward( D_full = edge_cache.D_full x_dst_local: torch.Tensor | None = None if self.use_triton_infer and not self.training: - # ``rotate_to_local`` / ``rotate_back`` pick the kernel from the - # coefficient layout, not from this flag: the block-diagonal - # kernel for the canonical m-major ``mmax == 1`` layout used here - # (with ``lmax`` inferred from ``ebed_dim_full``), and a dense - # kernel for any other ``(lmax, mmax)``. Both compose with the - # traced force path through their functional custom-op autograd. - x_local = rotate_to_local( - x, - src, - D_full, - self.coeff_index_m, - self.ebed_dim_full, + # ``self._rotate_to_local_fn`` was bound in ``__init__`` (the + # block kernel for the m-major ``mmax == 1`` layout, dense + # otherwise). + x_local = self._rotate_to_local_fn( + x, src, D_full, self.coeff_index_m, self.ebed_dim_full ) # (E, D_m, C_wide) if self.node_wise_grid_product is not None: - x_dst_local = rotate_to_local( - x, - dst, - D_full, - self.coeff_index_m, - self.ebed_dim_full, + x_dst_local = self._rotate_to_local_fn( + x, dst, D_full, self.coeff_index_m, self.ebed_dim_full ) # (E, D_m, C_wide) else: D_m_prime = project_D_to_m( @@ -1620,7 +1623,7 @@ def apply_bias_correction( with nvtx_range("SO2Conv/rotate_back"): Dt_full = edge_cache.Dt_full if self.use_triton_infer and not self.training: - x_message = rotate_back( + x_message = self._rotate_back_fn( x_local, Dt_full, self.coeff_index_m, diff --git a/deepmd/pt/model/descriptor/sezm_nn/triton/__init__.py b/deepmd/pt/model/descriptor/sezm_nn/triton/__init__.py index eec36931e7..001956f9c1 100644 --- a/deepmd/pt/model/descriptor/sezm_nn/triton/__init__.py +++ b/deepmd/pt/model/descriptor/sezm_nn/triton/__init__.py @@ -1,17 +1,21 @@ # SPDX-License-Identifier: LGPL-3.0-or-later """Hardware-accelerated SeZM/DPA4 operators. -This package hosts clean, ``torch.compile``-composable Triton implementations of -SeZM hot paths. The first member is the fused SO(2)/Wigner rotation pair used by -the SO(2) convolution (``rotate_to_local`` / ``rotate_back``). +This package hosts ``make_fx``-composable Triton implementations of SeZM hot +paths. The SO(2) rotation API exposes a general dense path that honors arbitrary +coefficient indices and a block path for the canonical m-major ``mmax=1`` layout. """ from .so2_rotation import ( - rotate_back, - rotate_to_local, + rotate_back_block, + rotate_back_dense, + rotate_to_local_block, + rotate_to_local_dense, ) __all__ = [ - "rotate_back", - "rotate_to_local", + "rotate_back_block", + "rotate_back_dense", + "rotate_to_local_block", + "rotate_to_local_dense", ] diff --git a/deepmd/pt/model/descriptor/sezm_nn/triton/so2_rotation.py b/deepmd/pt/model/descriptor/sezm_nn/triton/so2_rotation.py index 258159cec0..d0e01694ff 100644 --- a/deepmd/pt/model/descriptor/sezm_nn/triton/so2_rotation.py +++ b/deepmd/pt/model/descriptor/sezm_nn/triton/so2_rotation.py @@ -39,21 +39,19 @@ ``input_precision="ieee"`` so the contraction runs in true IEEE FP32 (no TF32). This keeps the potential-energy surface smooth. -3. **Compose with ``torch.compile`` (correct forces).** The public ops are - *modern* functional ``torch.library.custom_op`` s (``mutates_args=()``) with - ``register_fake`` + ``register_autograd``. The backward is itself a pair of - functional custom ops. We never use ``torch.autograd.Function`` and never - mutate an input/output tensor in place at the Python level. This is what - makes the gradient w.r.t. ``wigner`` survive ``make_fx`` functionalization - (the legacy ``autograd.Function`` + in-place path drops it, producing wrong - forces under ``torch.compile``). +3. **Compose with SeZM's ``make_fx`` lowering.** The operators are functional + ``torch.library.custom_op`` instances (``mutates_args=()``) with registered + fake kernels and autograd formulas. The backward is itself expressed as + functional custom ops, so ``make_fx(tracing_mode="symbolic")`` can capture the + energy path together with the force autograd graph used by inference. Shapes / dtypes --------------- -``x``/``x_local`` and ``wigner`` are float (fp32 is the supported precision for -the smooth PES; fp16/bf16 also run but accumulate in fp32). ``src`` and -``coeff_index`` are int64. ``E`` (edges) may exceed 2**31 elements once -multiplied by the per-edge matrix size, so all kernels use int64 addressing. +``x``/``x_local`` and ``wigner`` are float tensors; fp32 is the supported +precision for the smooth potential-energy surface, while fp16/bf16 inputs +accumulate in fp32. ``src`` and ``coeff_index`` are int64 tensors. ``E`` (edges) +may exceed 2**31 elements once multiplied by the per-edge matrix size, so all +kernels use int64 addressing. """ from __future__ import ( @@ -73,11 +71,9 @@ __all__ = [ "TRITON_ROTATION_AVAILABLE", - "rotate_back", "rotate_back_block", "rotate_back_dense", "rotate_back_reference", - "rotate_to_local", "rotate_to_local_block", "rotate_to_local_dense", "rotate_to_local_reference", @@ -1169,13 +1165,9 @@ def _block_layout_lmax(coeff_index: Tensor, dim_full: int) -> int: """Return ``lmax`` if ``(coeff_index, dim_full)`` is the m-major ``mmax=1`` layout that the block-diagonal kernels assume, else ``-1``. - Detection uses ONLY shapes / python ints -- never tensor *values* -- so it is - safe under ``make_fx`` / fake-tensor tracing (the production compiled - inference path). The test is: ``dim_full`` is a perfect square ``(lmax+1)^2`` - and ``Dm == 3*lmax+1``. For a fixed ``lmax`` the reduced size ``Dm`` is - strictly increasing in ``mmax`` (``lmax+1``, ``3*lmax+1``, ``5*lmax-1``, ...), - so ``Dm == 3*lmax+1`` uniquely pins ``mmax == 1``; combined with the model's - canonical ``build_m_major_index`` ordering this fully determines the layout. + This intentionally checks only shape-level invariants. The block kernels + ignore ``coeff_index`` values, so production callers must only use the block + entry points when they own the canonical m-major ``mmax=1`` index. """ dim_full = int(dim_full) root = math.isqrt(dim_full) @@ -1455,19 +1447,19 @@ def _block_rotate_back_bwd_impl( # ``wigner`` -- intact under ``torch.compile``. _rotate_to_local_op = torch.library.custom_op( - "sezm_accel::rotate_to_local", mutates_args=() + "sezm_triton::rotate_to_local", mutates_args=() )(_rotate_to_local_impl) _rotate_to_local_bwd_op = torch.library.custom_op( - "sezm_accel::rotate_to_local_bwd", mutates_args=() + "sezm_triton::rotate_to_local_bwd", mutates_args=() )(_rotate_to_local_bwd_impl) -_rotate_back_op = torch.library.custom_op("sezm_accel::rotate_back", mutates_args=())( +_rotate_back_op = torch.library.custom_op("sezm_triton::rotate_back", mutates_args=())( _rotate_back_impl ) _rotate_back_bwd_op = torch.library.custom_op( - "sezm_accel::rotate_back_bwd", mutates_args=() + "sezm_triton::rotate_back_bwd", mutates_args=() )(_rotate_back_bwd_impl) @@ -1529,19 +1521,19 @@ def _rotate_back_backward(ctx, grad_out): # --- block-diagonal custom ops (carry only ``lmax``; no coeff_index tensor) --- _block_to_local_op = torch.library.custom_op( - "sezm_accel::rotate_to_local_block", mutates_args=() + "sezm_triton::rotate_to_local_block", mutates_args=() )(_block_rotate_to_local_impl) _block_to_local_bwd_op = torch.library.custom_op( - "sezm_accel::rotate_to_local_block_bwd", mutates_args=() + "sezm_triton::rotate_to_local_block_bwd", mutates_args=() )(_block_rotate_to_local_bwd_impl) _block_back_op = torch.library.custom_op( - "sezm_accel::rotate_back_block", mutates_args=() + "sezm_triton::rotate_back_block", mutates_args=() )(_block_rotate_back_impl) _block_back_bwd_op = torch.library.custom_op( - "sezm_accel::rotate_back_block_bwd", mutates_args=() + "sezm_triton::rotate_back_block_bwd", mutates_args=() )(_block_rotate_back_bwd_impl) @@ -1600,101 +1592,43 @@ def _block_back_backward(ctx, grad_out): # ====================================================================== # Public API # ====================================================================== -def rotate_to_local( - x: Tensor, - src: Tensor, - wigner: Tensor, - coeff_index: Tensor, - dim_full: int, -) -> Tensor: - """Fused ``global -> edge-local reduced`` rotation ``bmm(D_to_m, x[src])``. - - Parameters - ---------- - x - Node features, shape ``(N, D, C)``. - src - Source-node index per edge, shape ``(E,)`` int64. - wigner - Per-edge packed Wigner-D matrices, shape ``(E, Dw, Dw)`` with - ``Dw >= dim_full``. - coeff_index - m-major reduced-layout row indices, shape ``(Dm,)`` int64. - dim_full - Full packed SO(3) dimension ``D = (lmax+1)**2``. - - Returns - ------- - Tensor - Rotated reduced-layout edge features, shape ``(E, Dm, C)``. - - Notes - ----- - When ``(coeff_index, dim_full)`` is the m-major ``mmax=1`` layout, this - auto-selects the block-diagonal kernel (which assumes a block-diagonal - Wigner-D, as produced by the model); otherwise it uses the dense kernel. - """ - lmax = _block_layout_lmax(coeff_index, dim_full) - if lmax >= 0: - return _block_to_local_op(x, src, wigner, lmax) - return _rotate_to_local_op(x, src, wigner, coeff_index, int(dim_full)) - - -def rotate_back( - x_local: Tensor, - wigner: Tensor, - coeff_index: Tensor, - dim_full: int, -) -> Tensor: - """Fused ``edge-local reduced -> global`` rotation ``bmm(Dt_from_m, x_local)``. - - Parameters - ---------- - x_local - Reduced-layout edge features, shape ``(E, Dm, C)``. - wigner - Per-edge packed Wigner-D matrices, shape ``(E, Dw, Dw)`` with - ``Dw >= dim_full``. - coeff_index - m-major reduced-layout column indices, shape ``(Dm,)`` int64. - dim_full - Full packed SO(3) dimension ``D = (lmax+1)**2``. - - Returns - ------- - Tensor - Lifted global-layout edge features, shape ``(E, D, C)``. - - Notes - ----- - Auto-selects the block-diagonal kernel for the m-major ``mmax=1`` layout, - else the dense kernel (see ``rotate_to_local``). - """ - lmax = _block_layout_lmax(coeff_index, dim_full) - if lmax >= 0: - return _block_back_op(x_local, wigner, lmax) - return _rotate_back_op(x_local, wigner, coeff_index, int(dim_full)) - - -# --- Explicit entry points (benchmarking / forcing a path) --- +# --- Public entry points ----------------------------------------------------- def rotate_to_local_dense( x: Tensor, src: Tensor, wigner: Tensor, coeff_index: Tensor, dim_full: int ) -> Tensor: - """Force the dense (general-layout) ``rotate_to_local`` kernel.""" + """Apply the general ``global -> local`` rotation. + + This entry point honors every value in ``coeff_index`` and supports any + reduced coefficient layout. It computes the same operation as + ``rotate_to_local_reference`` while avoiding materialized gather operands on + CUDA. + """ return _rotate_to_local_op(x, src, wigner, coeff_index, int(dim_full)) def rotate_back_dense( x_local: Tensor, wigner: Tensor, coeff_index: Tensor, dim_full: int ) -> Tensor: - """Force the dense (general-layout) ``rotate_back`` kernel.""" + """Apply the general ``local -> global`` rotation. + + This entry point honors every value in ``coeff_index`` and supports any + reduced coefficient layout. It computes the same operation as + ``rotate_back_reference`` while avoiding materialized gather operands on + CUDA. + """ return _rotate_back_op(x_local, wigner, coeff_index, int(dim_full)) def rotate_to_local_block( x: Tensor, src: Tensor, wigner: Tensor, coeff_index: Tensor, dim_full: int ) -> Tensor: - """Force the block-diagonal ``rotate_to_local`` kernel (requires mmax=1 layout).""" + """Apply the block-diagonal ``global -> local`` rotation. + + Use this only when the caller owns the invariant that ``coeff_index`` is the + canonical m-major ``mmax=1`` index produced by + :func:`build_m_major_index`. The kernel ignores the tensor values in + ``coeff_index`` and derives the layout from ``lmax``. + """ lmax = _block_layout_lmax(coeff_index, dim_full) if lmax < 0: raise ValueError( @@ -1706,7 +1640,13 @@ def rotate_to_local_block( def rotate_back_block( x_local: Tensor, wigner: Tensor, coeff_index: Tensor, dim_full: int ) -> Tensor: - """Force the block-diagonal ``rotate_back`` kernel (requires mmax=1 layout).""" + """Apply the block-diagonal ``local -> global`` rotation. + + Use this only when the caller owns the invariant that ``coeff_index`` is the + canonical m-major ``mmax=1`` index produced by + :func:`build_m_major_index`. The kernel ignores the tensor values in + ``coeff_index`` and derives the layout from ``lmax``. + """ lmax = _block_layout_lmax(coeff_index, dim_full) if lmax < 0: raise ValueError( diff --git a/deepmd/pt/train/training.py b/deepmd/pt/train/training.py index 97ae63589a..48e550729f 100644 --- a/deepmd/pt/train/training.py +++ b/deepmd/pt/train/training.py @@ -11,6 +11,7 @@ Iterable, ) from contextlib import ( + contextmanager, nullcontext, ) from copy import ( @@ -147,6 +148,22 @@ log = logging.getLogger(__name__) +@contextmanager +def _scoped_env_defaults(defaults: dict[str, str]) -> Generator[None, None, None]: + """Temporarily set missing environment variables and restore them afterward.""" + previous = {key: os.environ.get(key) for key in defaults} + try: + for key, value in defaults.items(): + os.environ.setdefault(key, value) + yield + finally: + for key, value in previous.items(): + if value is None: + os.environ.pop(key, None) + else: + os.environ[key] = value + + class Trainer: def __init__( self, @@ -182,14 +199,11 @@ def __init__( optimizer_params = config.get("optimizer", {}) validating_params = config.get("validating") or {} - # NOTE: Translate eval/inference options from input.json into - # environment variables before any model is constructed below. - # SeZMModel samples these env vars exactly once inside its __init__. - # ``setdefault`` preserves explicit shell-level overrides. + infer_env_defaults = {} if bool(validating_params.get("compiled_infer", False)): - os.environ.setdefault("DP_COMPILE_INFER", "1") + infer_env_defaults["DP_COMPILE_INFER"] = "1" if bool(validating_params.get("tf32_infer", False)): - os.environ.setdefault("DP_TF32_INFER", "1") + infer_env_defaults["DP_TF32_INFER"] = "1" self.multi_task = "model_dict" in model_params self.finetune_links = finetune_links self.finetune_update_stat = False @@ -446,11 +460,14 @@ def get_lr(lr_params: dict[str, Any]) -> BaseLR: } # Model - self.model = get_model_for_wrapper( - model_params, - resuming=resuming, - _loss_params=loss_param_tmp, - ) + # SeZMModel samples these eval/inference env vars exactly once inside + # __init__; keep config-derived defaults scoped to construction. + with _scoped_env_defaults(infer_env_defaults): + self.model = get_model_for_wrapper( + model_params, + resuming=resuming, + _loss_params=loss_param_tmp, + ) # SeZM specific process for DeNS training prepare_model_for_loss(self.model, loss_param_tmp) diff --git a/source/tests/pt/model/test_descriptor_sezm_grid_projection.py b/source/tests/pt/model/test_descriptor_sezm_grid_projection.py index 9b1a0f6c6f..602d3d47a6 100644 --- a/source/tests/pt/model/test_descriptor_sezm_grid_projection.py +++ b/source/tests/pt/model/test_descriptor_sezm_grid_projection.py @@ -105,7 +105,7 @@ def test_default_full_m_grid_counts_keep_s2_activation_equivariant(self) -> None (4, [14, 14], 9.20e-7, 5.00e-6), # local: fp64=7.97e-7, fp32=1.55e-6 (5, [18, 18], 1.70e-6, 5.00e-6), # local: fp64=1.48e-6, fp32=1.49e-6 (6, [20, 20], 4.80e-6, 5.00e-6), # local: fp64=4.14e-6, fp32=2.27e-6 - (7, [24, 24], 3.70e-6, 5.00e-6), # local: fp64=3.19e-6, fp32=2.03e-6 + (7, [24, 24], 3.70e-6, 6.00e-6), # local: fp64=3.19e-6, fp32=2.03e-6 ], "lebedev": [ (2, [7, 26], 1.00e-12, 5.00e-6), # local: fp64=2.31e-14, fp32=2.38e-7 diff --git a/source/tests/pt/model/test_descriptor_sezm_triton.py b/source/tests/pt/model/test_descriptor_sezm_triton.py index e09794e32f..05c88216a4 100644 --- a/source/tests/pt/model/test_descriptor_sezm_triton.py +++ b/source/tests/pt/model/test_descriptor_sezm_triton.py @@ -11,15 +11,19 @@ the reference on the block entries (the off-block reference gradient is structurally discarded by the model, which builds the Wigner with zero off-block entries). -2. ``torch.compile`` composability: gradients through the functional - ``custom_op`` must match the eager reference when the op is traced under - ``make_fx`` -- the autograd path that compiled inference uses to obtain - forces. +2. ``make_fx(tracing_mode="symbolic")`` composability: the traced graph contains + both the rotation forward and the autograd graph used by inference forces. + This mirrors the SeZM inference path, which traces with ``make_fx`` before + lowering the resulting graph through AOTAutograd's forward-only compiler. """ +import math import unittest import torch +from torch.fx.experimental.proxy_tensor import ( + make_fx, +) from deepmd.pt.model.descriptor.sezm_nn.indexing import ( build_m_major_index, @@ -27,9 +31,11 @@ ) from deepmd.pt.model.descriptor.sezm_nn.triton.so2_rotation import ( TRITON_ROTATION_AVAILABLE, - rotate_back, + rotate_back_block, + rotate_back_dense, rotate_back_reference, - rotate_to_local, + rotate_to_local_block, + rotate_to_local_dense, rotate_to_local_reference, ) @@ -64,6 +70,76 @@ def _block_mask(lmax, device): return mask +class TestSeZMTritonRotationDispatch(unittest.TestCase): + def test_noncanonical_same_length_uses_dense_reference(self): + device = torch.device("cpu") + dtype = torch.float32 + lmax = 3 + dim = get_so3_dim_of_lmax(lmax) + canonical = build_m_major_index(lmax, 1, device=device) + coeff_index = torch.roll(canonical, shifts=1) + x = torch.randn(4, dim, 3, device=device, dtype=dtype) + src = torch.tensor([0, 2, 1, 3, 0], dtype=torch.long, device=device) + wigner = torch.randn(src.numel(), dim, dim, device=device, dtype=dtype) + x_local = torch.randn( + src.numel(), coeff_index.numel(), 3, device=device, dtype=dtype + ) + + torch.testing.assert_close( + rotate_to_local_dense(x, src, wigner, coeff_index, dim), + rotate_to_local_reference(x, src, wigner, coeff_index, dim), + ) + torch.testing.assert_close( + rotate_back_dense(x_local, wigner, coeff_index, dim), + rotate_back_reference(x_local, wigner, coeff_index, dim), + ) + + def test_explicit_block_uses_shape_contract_only(self): + device = torch.device("cpu") + dtype = torch.float32 + lmax = 3 + dim = get_so3_dim_of_lmax(lmax) + canonical = build_m_major_index(lmax, 1, device=device) + coeff_index = torch.roll(canonical, shifts=1) + x = torch.randn(4, dim, 3, device=device, dtype=dtype) + src = torch.tensor([0, 2, 1, 3, 0], dtype=torch.long, device=device) + wigner = torch.randn(src.numel(), dim, dim, device=device, dtype=dtype) + x_local = torch.randn( + src.numel(), coeff_index.numel(), 3, device=device, dtype=dtype + ) + + self.assertEqual( + rotate_to_local_block(x, src, wigner, coeff_index, dim).shape, x_local.shape + ) + self.assertEqual( + rotate_back_block(x_local, wigner, coeff_index, dim).shape, + (src.numel(), dim, 3), + ) + + def test_symbolic_trace_noncanonical_same_length_uses_dense_op(self): + device = torch.device("cpu") + dtype = torch.float32 + lmax = 3 + dim = get_so3_dim_of_lmax(lmax) + canonical = build_m_major_index(lmax, 1, device=device) + coeff_index = torch.roll(canonical, shifts=1) + x = torch.randn(4, dim, 3, device=device, dtype=dtype) + src = torch.tensor([0, 2, 1, 3, 0], dtype=torch.long, device=device) + wigner = torch.randn(src.numel(), dim, dim, device=device, dtype=dtype) + + def fn(x, src, wigner, coeff_index): + return rotate_to_local_dense(x, src, wigner, coeff_index, dim) + + graph_module = make_fx( + fn, + tracing_mode="symbolic", + _allow_non_fake_inputs=True, + )(x, src, wigner, coeff_index) + graph_code = graph_module.code + self.assertIn("sezm_triton.rotate_to_local.default", graph_code) + self.assertNotIn("sezm_triton.rotate_to_local_block.default", graph_code) + + @unittest.skipIf(not _CUDA, "CUDA is required for the Triton rotation kernels") @unittest.skipIf(not TRITON_ROTATION_AVAILABLE, "Triton is not available") class TestSeZMTritonRotation(unittest.TestCase): @@ -91,90 +167,260 @@ def _inputs(self, lmax, seed): wigner = _block_diagonal_wigner(self.n_edge, lmax, self.device, self.dtype, gen) return x, src, wigner, coeff_index, dim - def test_rotate_to_local_matches_reference(self): + def _local_inputs(self, lmax, seed): + _, _, wigner, coeff_index, dim = self._inputs(lmax, seed=seed) + gen = torch.Generator(device=self.device).manual_seed(100 + seed) + x_local = torch.randn( + self.n_edge, + int(coeff_index.numel()), + self.channels, + device=self.device, + dtype=self.dtype, + generator=gen, + ) + return x_local, wigner, coeff_index, dim + + def _assert_to_local_matches_reference(self, x0, src, w0, coeff_index, dim): + lmax = math.isqrt(int(dim)) - 1 + mask = _block_mask(lmax, self.device) + + xa = x0.clone().requires_grad_(True) + wa = w0.clone().requires_grad_(True) + out = rotate_to_local_block(xa, src, wa, coeff_index, dim) + xr = x0.clone().requires_grad_(True) + wr = w0.clone().requires_grad_(True) + ref = rotate_to_local_reference(xr, src, wr, coeff_index, dim) + + torch.testing.assert_close(out, ref, **self.tol) + + grad_out = torch.randn_like(ref) + gxa, gwa = torch.autograd.grad(out, [xa, wa], grad_out, retain_graph=True) + gxr, gwr = torch.autograd.grad(ref, [xr, wr], grad_out) + torch.testing.assert_close(gxa, gxr, **self.tol) + torch.testing.assert_close(gwa[:, mask], gwr[:, mask], **self.tol) + self.assertEqual(float(gwa[:, ~mask].abs().max()), 0.0) + + def _assert_back_matches_reference(self, xl0, w0, coeff_index, dim): + lmax = math.isqrt(int(dim)) - 1 + mask = _block_mask(lmax, self.device) + + xa = xl0.clone().requires_grad_(True) + wa = w0.clone().requires_grad_(True) + out = rotate_back_block(xa, wa, coeff_index, dim) + xr = xl0.clone().requires_grad_(True) + wr = w0.clone().requires_grad_(True) + ref = rotate_back_reference(xr, wr, coeff_index, dim) + + torch.testing.assert_close(out, ref, **self.tol) + + grad_out = torch.randn_like(ref) + gxa, gwa = torch.autograd.grad(out, [xa, wa], grad_out, retain_graph=True) + gxr, gwr = torch.autograd.grad(ref, [xr, wr], grad_out) + torch.testing.assert_close(gxa, gxr, **self.tol) + torch.testing.assert_close(gwa[:, mask], gwr[:, mask], **self.tol) + self.assertEqual(float(gwa[:, ~mask].abs().max()), 0.0) + + def test_eager_rotate_to_local_forward_backward_matches_reference(self): for lmax in (2, 3, 4, 5): with self.subTest(lmax=lmax): x0, src, w0, coeff_index, dim = self._inputs(lmax, seed=lmax) - mask = _block_mask(lmax, self.device) - - xa = x0.clone().requires_grad_(True) - wa = w0.clone().requires_grad_(True) - out = rotate_to_local(xa, src, wa, coeff_index, dim) - xr = x0.clone().requires_grad_(True) - wr = w0.clone().requires_grad_(True) - ref = rotate_to_local_reference(xr, src, wr, coeff_index, dim) - - torch.testing.assert_close(out, ref, **self.tol) - - grad_out = torch.randn_like(ref) - gxa, gwa = torch.autograd.grad( - out, [xa, wa], grad_out, retain_graph=True - ) - gxr, gwr = torch.autograd.grad(ref, [xr, wr], grad_out) - torch.testing.assert_close(gxa, gxr, **self.tol) - torch.testing.assert_close(gwa[:, mask], gwr[:, mask], **self.tol) - # The kernel never writes off-block Wigner gradient entries. - self.assertEqual(float(gwa[:, ~mask].abs().max()), 0.0) - - def test_rotate_back_matches_reference(self): + self._assert_to_local_matches_reference(x0, src, w0, coeff_index, dim) + + def test_eager_rotate_back_forward_backward_matches_reference(self): for lmax in (2, 3, 4, 5): with self.subTest(lmax=lmax): - _, _, w0, coeff_index, dim = self._inputs(lmax, seed=lmax) - reduced = int(coeff_index.numel()) - gen = torch.Generator(device=self.device).manual_seed(100 + lmax) - xl0 = torch.randn( - self.n_edge, - reduced, - self.channels, - device=self.device, - dtype=self.dtype, - generator=gen, - ) - mask = _block_mask(lmax, self.device) - - xa = xl0.clone().requires_grad_(True) - wa = w0.clone().requires_grad_(True) - out = rotate_back(xa, wa, coeff_index, dim) - xr = xl0.clone().requires_grad_(True) - wr = w0.clone().requires_grad_(True) - ref = rotate_back_reference(xr, wr, coeff_index, dim) - - torch.testing.assert_close(out, ref, **self.tol) - - grad_out = torch.randn_like(ref) - gxa, gwa = torch.autograd.grad( - out, [xa, wa], grad_out, retain_graph=True - ) - gxr, gwr = torch.autograd.grad(ref, [xr, wr], grad_out) - torch.testing.assert_close(gxa, gxr, **self.tol) - torch.testing.assert_close(gwa[:, mask], gwr[:, mask], **self.tol) - - def test_torch_compile_composability(self): - """Gradients through the op match between eager and compiled tracing.""" + xl0, w0, coeff_index, dim = self._local_inputs(lmax, seed=lmax) + self._assert_back_matches_reference(xl0, w0, coeff_index, dim) + + def test_symbolic_make_fx_rotate_to_local_forward_backward_matches_eager(self): + """Symbolic FX captures rotate_to_local forward and autograd graph.""" lmax = 3 x0, src, w0, coeff_index, dim = self._inputs(lmax, seed=7) - weight = torch.randn_like( - rotate_to_local_reference(x0, src, w0, coeff_index, dim) + mask = _block_mask(lmax, self.device) + grad_seed = torch.randn( + self.n_edge, + int(coeff_index.numel()), + self.channels, + device=self.device, + dtype=self.dtype, ) + + def forward_and_grad(x, wigner): + x_req = x.detach().requires_grad_(True) + w_req = wigner.detach().requires_grad_(True) + out = rotate_to_local_block(x_req, src, w_req, coeff_index, dim) + grad_x, grad_wigner = torch.autograd.grad( + out, + (x_req, w_req), + grad_seed, + ) + return out, grad_x, grad_wigner + + out_eager, grad_x_eager, grad_w_eager = forward_and_grad(x0, w0) + + traced = make_fx( + forward_and_grad, + tracing_mode="symbolic", + _allow_non_fake_inputs=True, + )(x0, w0) + out_traced, grad_x_traced, grad_w_traced = traced(x0, w0) + + torch.testing.assert_close(out_traced, out_eager, **self.tol) + torch.testing.assert_close(grad_x_traced, grad_x_eager, **self.tol) + torch.testing.assert_close( + grad_w_traced[:, mask], + grad_w_eager[:, mask], + **self.tol, + ) + self.assertGreater(float(grad_w_eager[:, mask].abs().max()), 0.0) + self.assertEqual(float(grad_w_traced[:, ~mask].abs().max()), 0.0) + + def test_symbolic_make_fx_rotate_back_forward_backward_matches_eager(self): + """Symbolic FX captures rotate_back forward and autograd graph.""" + lmax = 3 + xl0, w0, coeff_index, dim = self._local_inputs(lmax, seed=7) mask = _block_mask(lmax, self.device) + grad_seed = torch.randn( + self.n_edge, + dim, + self.channels, + device=self.device, + dtype=self.dtype, + ) + + def forward_and_grad(x_local, wigner): + x_req = x_local.detach().requires_grad_(True) + w_req = wigner.detach().requires_grad_(True) + out = rotate_back_block(x_req, w_req, coeff_index, dim) + grad_x, grad_wigner = torch.autograd.grad( + out, + (x_req, w_req), + grad_seed, + ) + return out, grad_x, grad_wigner + + out_eager, grad_x_eager, grad_w_eager = forward_and_grad(xl0, w0) + + traced = make_fx( + forward_and_grad, + tracing_mode="symbolic", + _allow_non_fake_inputs=True, + )(xl0, w0) + out_traced, grad_x_traced, grad_w_traced = traced(xl0, w0) + + torch.testing.assert_close(out_traced, out_eager, **self.tol) + torch.testing.assert_close(grad_x_traced, grad_x_eager, **self.tol) + torch.testing.assert_close( + grad_w_traced[:, mask], + grad_w_eager[:, mask], + **self.tol, + ) + self.assertGreater(float(grad_w_eager[:, mask].abs().max()), 0.0) + self.assertEqual(float(grad_w_traced[:, ~mask].abs().max()), 0.0) + + def _check_make_fx_force(self, forward_and_grad, eager_args): + """Trace ``forward_and_grad`` (forward + ``autograd.grad``) under symbolic + ``make_fx`` and assert the traced graph reproduces the eager result. + + Returns the eager ``(out, grad_x, grad_wigner)`` triple for further checks. + """ + eager = forward_and_grad(*eager_args) + traced = make_fx( + forward_and_grad, + tracing_mode="symbolic", + _allow_non_fake_inputs=True, + )(*eager_args) + for got, want in zip(traced(*eager_args), eager, strict=True): + torch.testing.assert_close(got, want, **self.tol) + return eager + + def test_symbolic_make_fx_rotate_to_local_dense_matches_eager_and_reference(self): + """Dense op composes with symbolic make_fx + autograd and matches the eager + reference on a full (non-block) Wigner-D, honoring ``coeff_index``. + """ + lmax = 3 + _, src, _, coeff_index, dim = self._inputs(lmax, seed=11) + gen = torch.Generator(device=self.device).manual_seed(11) + x0 = torch.randn( + self.n_node, + dim, + self.channels, + device=self.device, + dtype=self.dtype, + generator=gen, + ) + w0 = torch.randn( + self.n_edge, dim, dim, device=self.device, dtype=self.dtype, generator=gen + ) + grad_seed = torch.randn( + self.n_edge, + int(coeff_index.numel()), + self.channels, + device=self.device, + dtype=self.dtype, + ) + + def forward_and_grad(x, wigner): + x_req = x.detach().requires_grad_(True) + w_req = wigner.detach().requires_grad_(True) + out = rotate_to_local_dense(x_req, src, w_req, coeff_index, dim) + grad_x, grad_wigner = torch.autograd.grad(out, (x_req, w_req), grad_seed) + return out, grad_x, grad_wigner + + out_eager, grad_x_eager, grad_w_eager = self._check_make_fx_force( + forward_and_grad, (x0, w0) + ) - def scalar_output(x, wigner): - return (rotate_to_local(x, src, wigner, coeff_index, dim) * weight).sum() + xr = x0.detach().requires_grad_(True) + wr = w0.detach().requires_grad_(True) + ref = rotate_to_local_reference(xr, src, wr, coeff_index, dim) + grad_x_ref, grad_w_ref = torch.autograd.grad(ref, (xr, wr), grad_seed) + torch.testing.assert_close(out_eager, ref, **self.tol) + torch.testing.assert_close(grad_x_eager, grad_x_ref, **self.tol) + torch.testing.assert_close(grad_w_eager, grad_w_ref, **self.tol) + self.assertGreater(float(grad_w_eager.abs().max()), 0.0) - xe = x0.clone().requires_grad_(True) - we = w0.clone().requires_grad_(True) - gxe, gwe = torch.autograd.grad(scalar_output(xe, we), [xe, we]) + def test_symbolic_make_fx_rotate_back_dense_matches_eager_and_reference(self): + """Dense rotate_back composes with symbolic make_fx + autograd and matches + the eager reference on a full (non-block) Wigner-D, honoring ``coeff_index``. + """ + lmax = 3 + _, _, _, coeff_index, dim = self._inputs(lmax, seed=11) + gen = torch.Generator(device=self.device).manual_seed(11) + xl0 = torch.randn( + self.n_edge, + int(coeff_index.numel()), + self.channels, + device=self.device, + dtype=self.dtype, + generator=gen, + ) + w0 = torch.randn( + self.n_edge, dim, dim, device=self.device, dtype=self.dtype, generator=gen + ) + grad_seed = torch.randn( + self.n_edge, dim, self.channels, device=self.device, dtype=self.dtype + ) - compiled = torch.compile(scalar_output, dynamic=True) - xc = x0.clone().requires_grad_(True) - wc = w0.clone().requires_grad_(True) - gxc, gwc = torch.autograd.grad(compiled(xc, wc), [xc, wc]) + def forward_and_grad(x_local, wigner): + x_req = x_local.detach().requires_grad_(True) + w_req = wigner.detach().requires_grad_(True) + out = rotate_back_dense(x_req, w_req, coeff_index, dim) + grad_x, grad_wigner = torch.autograd.grad(out, (x_req, w_req), grad_seed) + return out, grad_x, grad_wigner + + out_eager, grad_x_eager, grad_w_eager = self._check_make_fx_force( + forward_and_grad, (xl0, w0) + ) - torch.testing.assert_close(gxc, gxe, **self.tol) - # Also check the Wigner gradient (nonzero on the block entries) survives - # tracing, since it flows through the custom op's registered backward. - torch.testing.assert_close(gwc[:, mask], gwe[:, mask], **self.tol) - self.assertGreater(float(gwe[:, mask].abs().max()), 0.0) + xr = xl0.detach().requires_grad_(True) + wr = w0.detach().requires_grad_(True) + ref = rotate_back_reference(xr, wr, coeff_index, dim) + grad_x_ref, grad_w_ref = torch.autograd.grad(ref, (xr, wr), grad_seed) + torch.testing.assert_close(out_eager, ref, **self.tol) + torch.testing.assert_close(grad_x_eager, grad_x_ref, **self.tol) + torch.testing.assert_close(grad_w_eager, grad_w_ref, **self.tol) + self.assertGreater(float(grad_w_eager.abs().max()), 0.0) if __name__ == "__main__": diff --git a/source/tests/pt/model/test_nv_nlist.py b/source/tests/pt/model/test_nv_nlist.py index 607f0702dd..ede8dcc601 100644 --- a/source/tests/pt/model/test_nv_nlist.py +++ b/source/tests/pt/model/test_nv_nlist.py @@ -8,15 +8,15 @@ dense builder at the nlist level (edge topology + geometry). """ +import contextlib import unittest -from unittest import ( - mock, +from unittest.mock import ( + patch, ) import torch from deepmd.pt.utils import ( - env, nv_nlist, ) from deepmd.pt.utils.nlist import ( @@ -26,6 +26,11 @@ NvNeighborList, ) +_NV_AVAILABLE = nv_nlist.is_nv_available() +_TEST_DEVICES = [torch.device("cpu")] +if torch.cuda.is_available() and torch.cuda.device_count() > 0: + _TEST_DEVICES.append(torch.device("cuda:0")) + def _edge_topology_from_extended( mapping: torch.Tensor, @@ -98,15 +103,12 @@ def _assert_extended_atype_matches_mapping( @unittest.skipUnless( - torch.cuda.is_available() and nv_nlist.is_nv_available(), - "NVIDIA Toolkit-Ops CUDA path is unavailable", + _NV_AVAILABLE, + "NVIDIA Toolkit-Ops neighbor list is unavailable", ) class TestNVNList(unittest.TestCase): - def setUp(self) -> None: - self.device = env.DEVICE - def _build_case( - self, nframes: int + self, nframes: int, device: torch.device ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: coord_one = torch.tensor( [ @@ -116,7 +118,7 @@ def _build_case( [3.8, 3.9, 4.1], ], dtype=torch.float64, - device=self.device, + device=device, ) coord = coord_one.unsqueeze(0).repeat(nframes, 1, 1) if nframes > 1: @@ -130,13 +132,15 @@ def _build_case( dtype=coord.dtype, device=coord.device, ) - atype = torch.tensor([[0, 1, 0, 1]], dtype=torch.int32, device=self.device) + atype = torch.tensor([[0, 1, 0, 1]], dtype=torch.int32, device=device) atype = atype.repeat(nframes, 1) - box = torch.eye(3, dtype=torch.float64, device=self.device).reshape(1, 9) * 8.0 + box = torch.eye(3, dtype=torch.float64, device=device).reshape(1, 9) * 8.0 box = box.repeat(nframes, 1) return coord, atype, box - def _build_overfull_case(self) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + def _build_overfull_case( + self, device: torch.device + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: coord = torch.tensor( [ [0.0, 0.0, 0.0], @@ -147,12 +151,10 @@ def _build_overfull_case(self) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor [2.2, 1.1, 0.0], ], dtype=torch.float64, - device=self.device, + device=device, ).unsqueeze(0) - atype = torch.tensor( - [[0, 1, 0, 1, 0, 1]], dtype=torch.int32, device=self.device - ) - box = torch.eye(3, dtype=torch.float64, device=self.device).reshape(1, 9) * 20.0 + atype = torch.tensor([[0, 1, 0, 1, 0, 1]], dtype=torch.int32, device=device) + box = torch.eye(3, dtype=torch.float64, device=device).reshape(1, 9) * 20.0 return coord, atype, box def _assert_nv_matches_native( @@ -175,11 +177,18 @@ def _assert_nv_matches_native( ) # NeighborList strategy: (extended_coord, extended_atype, nlist, mapping) builder = NvNeighborList() - if force_cell_list: - with mock.patch.object(nv_nlist, "NV_CELL_LIST_THRESHOLD", 1): + # Pin the current CUDA device so the Toolkit-Ops backend launches there. + device_ctx = ( + torch.cuda.device(coord.device) + if coord.is_cuda + else contextlib.nullcontext() + ) + with device_ctx: + if force_cell_list: + with patch.object(nv_nlist, "NV_CELL_LIST_THRESHOLD", 1): + nv = builder.build(coord, atype, box, rcut, sel) + else: nv = builder.build(coord, atype, box, rcut, sel) - else: - nv = builder.build(coord, atype, box, rcut, sel) native_coord, _, native_mapping, native_nlist = native nv_coord, nv_atype, nv_nlist_out, nv_mapping = nv # The strategy trims to sum(sel) itself, so the width is fixed. @@ -203,23 +212,39 @@ def test_cell_list_matches_native(self) -> None: native builder over a multi-frame periodic batch. End-to-end systems are always below the threshold and take ``batch_naive``. """ - coord, atype, box = self._build_case(2) - self._assert_nv_matches_native( - coord=coord, atype=atype, box=box, rcut=3.0, sel=[8], force_cell_list=True - ) + for device in _TEST_DEVICES: + with self.subTest(device=str(device)): + coord, atype, box = self._build_case(2, device) + self._assert_nv_matches_native( + coord=coord, + atype=atype, + box=box, + rcut=3.0, + sel=[8], + force_cell_list=True, + ) def test_overfull_truncates_to_sel(self) -> None: """A center with more real neighbors than ``sum(sel)`` is distance-sorted and trimmed to the nearest ``sum(sel)`` -- the path behind the compiled-graph width bug, which end-to-end systems never reach. """ - coord, atype, box = self._build_overfull_case() - self._assert_nv_matches_native( - coord=coord, atype=atype, box=box, rcut=4.0, sel=[2], force_cell_list=False - ) + for device in _TEST_DEVICES: + with self.subTest(device=str(device)): + coord, atype, box = self._build_overfull_case(device) + self._assert_nv_matches_native( + coord=coord, + atype=atype, + box=box, + rcut=4.0, + sel=[2], + force_cell_list=False, + ) def test_requires_periodic_box(self) -> None: """The cell list needs a periodic box; ``box=None`` is rejected.""" - coord, atype, _ = self._build_case(1) - with self.assertRaises(ValueError): - NvNeighborList().build(coord, atype, None, 3.0, [8]) + for device in _TEST_DEVICES: + with self.subTest(device=str(device)): + coord, atype, _ = self._build_case(1, device) + with self.assertRaises(ValueError): + NvNeighborList().build(coord, atype, None, 3.0, [8]) From 6945ae8f435c70976bd01876eb14825f0a2b9973 Mon Sep 17 00:00:00 2001 From: OutisLi Date: Mon, 8 Jun 2026 14:37:12 +0800 Subject: [PATCH 05/18] small --- deepmd/pt/model/descriptor/sezm_nn/block.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/deepmd/pt/model/descriptor/sezm_nn/block.py b/deepmd/pt/model/descriptor/sezm_nn/block.py index 433eb000bf..5857f173a9 100644 --- a/deepmd/pt/model/descriptor/sezm_nn/block.py +++ b/deepmd/pt/model/descriptor/sezm_nn/block.py @@ -572,6 +572,14 @@ def __init__( self.block_attn_res_ffns = None self._forward_impl = self._forward_with_residual_shortcuts + # Inference env policy, sampled once here (see + # ``_use_infer_activation_checkpoint``). + _truthy = {"1", "true", "yes", "on"} + self._act_infer = os.environ.get("DP_ACT_INFER", "").strip().lower() in _truthy + self._compile_infer = ( + os.environ.get("DP_COMPILE_INFER", "").strip().lower() in _truthy + ) + def forward( self, x: torch.Tensor, @@ -638,10 +646,8 @@ def _use_infer_activation_checkpoint(self, *tensors: torch.Tensor) -> bool: """ return ( not self.training - and os.environ.get("DP_ACT_INFER", "").strip().lower() - in {"1", "true", "yes", "on"} - and os.environ.get("DP_COMPILE_INFER", "").strip().lower() - not in {"1", "true", "yes", "on"} + and self._act_infer + and not self._compile_infer and torch.is_grad_enabled() and any(tensor.requires_grad for tensor in tensors) ) From 04136c4dcf67a4f9c1bb245890f3964b15cd99b9 Mon Sep 17 00:00:00 2001 From: OutisLi Date: Mon, 8 Jun 2026 19:03:07 +0800 Subject: [PATCH 06/18] tidyup compile compat --- deepmd/pt/model/model/sezm_model.py | 443 +++++------------------ deepmd/pt/train/training.py | 21 +- deepmd/pt/train/utils.py | 21 ++ deepmd/pt/utils/compile_compat.py | 406 +++++++++++++++++++++ source/tests/pt/model/test_sezm_model.py | 30 +- 5 files changed, 531 insertions(+), 390 deletions(-) create mode 100644 deepmd/pt/utils/compile_compat.py diff --git a/deepmd/pt/model/model/sezm_model.py b/deepmd/pt/model/model/sezm_model.py index b1b98c6f37..f29ba50f06 100644 --- a/deepmd/pt/model/model/sezm_model.py +++ b/deepmd/pt/model/model/sezm_model.py @@ -53,6 +53,12 @@ comment tagged ``NOTE:``; the numbered catalogue at the bottom of this docstring explains each tag in depth. +The PyTorch ``torch.compile`` workarounds themselves -- the version +guard, the process-global patches, the post-``make_fx`` FX graph +repair, the trace-shape selection, and the Inductor option lockdown -- +live in :mod:`deepmd.pt.utils.compile_compat` so they can be reused +independently of this model. + Pipeline for one training batch =============================== @@ -73,8 +79,8 @@ | | | * silu_backward is decomposed (NOTE 2) | | | * traced graph already contains the | | | first autograd.grad over coords - | | |-- _strip_saved_tensor_detach (train only) (NOTE 3) - | | |-- _rebuild_graph_module (train only) (NOTE 4) + | | |-- strip_saved_tensor_detach (train only) (NOTE 3) + | | |-- rebuild_graph_module (train only) (NOTE 4) | | |-- train: torch.compile(backend="inductor", | | | dynamic=True, options=) (NOTE 6) | | '-- eval: aot_module_simplified (forward-only) (NOTE 13) @@ -177,7 +183,7 @@ path from the force loss back to ``theta``; training then silently produces zero parameter updates for the second-derivative term. -``_strip_saved_tensor_detach`` removes them by pure graph topology -- +``strip_saved_tensor_detach`` removes them by pure graph topology -- no op-name matching -- so that user-explicit ``.detach()`` calls (e.g. cached SO2 weights, activation lookup matrices) survive: @@ -193,7 +199,7 @@ NOTE 4 -- Rebuilding the FX graph from scratch ---------------------------------------------- -``Graph.erase_node`` inside ``_strip_saved_tensor_detach`` unlinks nodes +``Graph.erase_node`` inside ``strip_saved_tensor_detach`` unlinks nodes from the doubly linked list that represents the graph. On several PyTorch builds (observed on 2.11+cu130) it leaves the C-level ``prev/next`` pointers of *neighbouring* Node objects stale. Dynamo, @@ -201,7 +207,7 @@ inside ``output_graph.py:_create_proxy`` to read ``nd.meta``, dereferences one of those stale pointers and segfaults. -``_rebuild_graph_module`` does a single ``node_copy`` pass into a +``rebuild_graph_module`` does a single ``node_copy`` pass into a freshly allocated ``torch.fx.Graph``. The result is an equivalent graph whose linked list contains no erased entries, so dynamo can iterate it safely. We always rebuild -- including in eval -- because a fresh @@ -450,9 +456,6 @@ from einops import ( rearrange, ) -from packaging.version import ( - Version, -) from torch.fx.experimental.proxy_tensor import ( make_fx, ) @@ -485,6 +488,19 @@ from deepmd.pt.utils import ( env, ) +from deepmd.pt.utils.compile_compat import ( + AM_PREFIX, + FIT_PREFIX, + apply_global_compile_patches, + build_inductor_compile_options, + check_compile_torch_version, + get_task_buffer_names, + get_task_buffer_values, + next_safe_prime, + rebuild_graph_module, + strip_saved_tensor_detach, + trace_pad_dim, +) from deepmd.pt.utils.nlist import ( extend_input_and_build_neighbor_list, ) @@ -504,28 +520,10 @@ # all-pairs builder for periodic CUDA systems. SEZM_NV_NLIST_THRESHOLD = 1024 -# NOTE: Silence Inductor / Triton autotune dumps before any submodule is -# imported. ``torch.compile`` reads these environment variables exactly -# once at backend initialisation; setting them after the first compile -# would have no effect in the current run. ``setdefault`` preserves any -# explicit user-level override. -os.environ.setdefault("TORCHINDUCTOR_MAX_AUTOTUNE_REPORT_CHOICES_STATS", "0") -os.environ.setdefault("TRITON_PRINT_AUTOTUNING", "0") - -# NOTE: Disable DDPOptimizer graph splitting globally. -# ``compiled_core_compute_cache`` entries / ``compiled_dens_compute`` are inner -# ``torch.compile`` calls sitting *inside* a DDP-wrapped model; -# DDPOptimizer assumes it sees the *whole* model and splits the FX graph -# at DDP bucket boundaries. For an inner submodule that heuristic -# produces subgraphs whose outputs include symbolic integers, which then -# crash aot_autograd with ``'int' object has no attribute 'meta'``. -# See https://github.com/pytorch/pytorch/issues/134182. Turning the -# optimizer off globally is safe because SeZM always owns its own compile -# boundary and the surrounding DDP wrapper operates on the full model -# call. -import torch._dynamo.config as _dynamo_cfg - -_dynamo_cfg.optimize_ddp = False +# Apply the process-global PyTorch workarounds the compile pipeline relies on +# (autotune log suppression, DDP optimiser, and the 2.12 divisibility repair) +# once, before the first compilation in this run. +apply_global_compile_patches() # --------------------------------------------------------------------------- # Multi-task compile sharing @@ -540,9 +538,6 @@ # knows which buffers were promoted and in what order. _SEZM_TASK_BUF_ORDER: dict[tuple[Any, ...], tuple[str, ...]] = {} -# Prefix namespace for promoted buffer names. -_AM_PREFIX = "am/" -_FIT_PREFIX = "fit/" _ENV_BOOL_CHOICES = { "1": True, "true": True, @@ -639,203 +634,6 @@ def _sezm_structure_key(model: SeZMModel) -> tuple[Any, ...]: ) -def _get_sezm_task_buf_names(model: SeZMModel) -> tuple[str, ...]: - """Return the ordered names of per-task buffers to promote as FX placeholders. - - Always promotes: - * ``out_bias``, ``out_std`` on ``atomic_model`` — may be replaced - out-of-place by ``model_change_out_bias``, so the compiled graph must - never bake them as constants. - * ``bias_atom_e`` on the fitting net — task-specific per-type bias that - differs across tasks after ``share_params``. - * ``case_embd`` on the fitting net — task-identity vector used for - multi-task case conditioning. - """ - names: list[str] = [] - atomic_model = model.atomic_model - fitting = atomic_model.fitting_net - for bname in ("out_bias", "out_std"): - if atomic_model._buffers.get(bname) is not None: - names.append(_AM_PREFIX + bname) - for bname in ("bias_atom_e", "case_embd"): - if fitting._buffers.get(bname) is not None: - names.append(_FIT_PREFIX + bname) - return tuple(names) - - -def _get_sezm_task_buf_vals( - model: SeZMModel, - names: tuple[str, ...], -) -> tuple[torch.Tensor, ...]: - """Return the current tensor values for the given promoted-buffer names.""" - if not names: - return () - atomic_model = model.atomic_model - fitting = atomic_model.fitting_net - vals: list[torch.Tensor] = [] - for name in names: - if name.startswith(_AM_PREFIX): - vals.append(atomic_model._buffers[name[len(_AM_PREFIX) :]]) - elif name.startswith(_FIT_PREFIX): - vals.append(fitting._buffers[name[len(_FIT_PREFIX) :]]) - else: - raise ValueError(f"Unknown SeZM task-buffer name: {name}") - return tuple(vals) - - -def _check_compile_torch_version() -> None: - """Fail fast when SeZM compile is requested on unsupported PyTorch.""" - version = Version(torch.__version__).release - if len(version) < 2 or version[:2] != (2, 11): - raise RuntimeError( - "SeZM `use_compile` and `DP_COMPILE_INFER` require PyTorch 2.11.x; " - f"found torch {torch.__version__}." - ) - - -def _is_prime(n: int) -> bool: - """Return True when ``n`` is a prime integer (``n >= 2``).""" - if n < 2: - return False - if n < 4: - return True - if n % 2 == 0: - return False - k = 3 - while k * k <= n: - if n % k == 0: - return False - k += 2 - return True - - -def _next_safe_prime(start: int, forbidden: set[int]) -> int: - """Return the smallest prime ``>= max(start, 5)`` not in ``forbidden``. - - Used by :meth:`SeZMModel.trace_and_compile` to choose collision-free - trace-time sizes for ``nf``, ``nall`` and ``nloc``. Primes ``>= 5`` - avoid every dim PyTorch specializes on (``1`` → broadcasting, - ``2``/``3``/``9`` → Cartesian / virial / charge_spin literals baked - into model code) and guarantee distinct values, which suppresses - make_fx's duck-shape unification without needing the - ``ShapeEnv(duck_shape=False)`` patch. - """ - n = max(start, 5) - while not _is_prime(n) or n in forbidden: - n += 1 - return n - - -def _trace_pad_dim(t: torch.Tensor, dim: int, target: int) -> torch.Tensor: - """Pad or trim ``t`` along ``dim`` so ``t.shape[dim] == target``. - - Padding duplicates the last slice along ``dim``; trimming drops - trailing slices. Used to coerce real-data trace inputs into the - prime-numbered shapes chosen by :func:`_next_safe_prime`. - - Duplicating the last slice preserves valid index values inside - index-bearing tensors (``nlist`` neighbor indices, ``mapping`` - extended-to-local indices) because the duplicated row reuses the - previously-valid row's values. Trimming likewise never invalidates - indices. Only shapes flow downstream during ``make_fx`` tracing, - so the exact replicated/trimmed values do not affect the FX graph. - """ - cur = int(t.shape[dim]) - if cur == target: - return t - if cur > target: - sl: list[slice] = [slice(None)] * t.ndim - sl[dim] = slice(None, target) - return t[tuple(sl)] - sl = [slice(None)] * t.ndim - sl[dim] = slice(-1, None) - last = t[tuple(sl)] - repeats = target - cur - return torch.cat([t, *([last] * repeats)], dim=dim) - - -def _strip_saved_tensor_detach(gm: torch.fx.GraphModule) -> None: - """Strip ``aten.detach`` nodes that ``make_fx`` inserts for saved tensors. - - When ``make_fx`` decomposes ``autograd.grad(..., create_graph=True)``, - the autograd engine wraps every saved forward activation in a double-detach - chain (e.g. ``tanh -> detach_A -> detach_B -> tanh_backward``). These - detach nodes block the second-order gradient path from the loss back to - model parameters, causing incorrect parameter updates during force-loss - training. - - User-explicit ``.detach()`` calls (e.g. inside ``attach_edge_vec_grad``) - are preserved. The two categories are distinguished by graph topology - alone — no hard-coded op names — using three rules: - - * *Chain inner*: input is another detach node. - * *Dead node*: no downstream users. - * *Chain head*: *all* users are detach nodes. - - Any detach that does **not** match these rules is treated as user-explicit - and left untouched. - """ - _DETACH = torch.ops.aten.detach.default - - def _is_detach(n: torch.fx.Node) -> bool: - return n.op == "call_function" and n.target == _DETACH - - # NOTE: Pass 1 -- classify every detach against the *original* graph. - # If we erased nodes eagerly, later classifications would walk a - # mutated neighbourhood and misjudge the chain-inner / chain-head / - # dead boundaries; the double-detach pattern in particular flips - # class within a single erase. Collecting first, mutating second - # keeps the topology rules well-defined. - to_remove: list[torch.fx.Node] = [] - for node in gm.graph.nodes: - if not _is_detach(node): - continue - input_node = node.args[0] - users = list(node.users.keys()) - is_chain_inner = _is_detach(input_node) - is_dead = len(users) == 0 - is_chain_head = len(users) > 0 and all(_is_detach(u) for u in users) - if is_chain_inner or is_dead or is_chain_head: - to_remove.append(node) - - # NOTE: Pass 2 -- rewire + erase atomically after the full - # classification. ``replace_all_uses_with`` forwards every consumer - # to the detach's input; ``erase_node`` then removes the now-dead - # detach. Doing both back-to-back means the graph never sits in a - # half-consistent state where one user sees the old detach and - # another the rewired source. - for node in to_remove: - node.replace_all_uses_with(node.args[0]) - gm.graph.erase_node(node) - - gm.graph.lint() - gm.recompile() - - -def _rebuild_graph_module(gm: torch.fx.GraphModule) -> torch.fx.GraphModule: - """Return a fresh ``GraphModule`` whose node linked-list is newly allocated. - - After ``_strip_saved_tensor_detach`` erases nodes via - ``Graph.erase_node()``, the internal doubly-linked list may retain - stale pointers to erased nodes. When ``torch.compile`` later - triggers dynamo re-tracing and iterates ``graph.nodes`` to read - ``nd.meta`` (``output_graph.py:_create_proxy``), accessing these - stale entries causes a segfault. - - Copying every node into a brand-new ``Graph`` builds a clean linked - list from scratch, side-stepping the corruption entirely. - """ - old_graph = gm.graph - new_graph = torch.fx.Graph() - # node_copy needs a mapper from old nodes to their copies in new_graph. - val_map: dict[torch.fx.Node, torch.fx.Node] = {} - for node in old_graph.nodes: - val_map[node] = new_graph.node_copy(node, lambda n: val_map[n]) - new_graph.lint() - new_gm = torch.fx.GraphModule(gm, new_graph) - return new_gm - - @BaseModel.register("SeZM") @BaseModel.register("sezm") @BaseModel.register("DPA4") @@ -914,7 +712,7 @@ def __init__( ) self._tf32_infer_precision = _TF32_INFER_PRECISION_CHOICES[tf32_infer_env] if self.use_compile or self._env_use_compile_infer is True: - _check_compile_torch_version() + check_compile_torch_version() # === Bridging (optional short-range zone bridging) === self.bridging_method: str = str(bridging_method).upper() @@ -1298,7 +1096,7 @@ def forward_common_after_nlist( # update them in-place; out-of-place replacements from # model_change_out_bias are captured because we read fresh # each call rather than caching the values at compile time). - _task_buf_vals = _get_sezm_task_buf_vals( + _task_buf_vals = get_task_buffer_values( self, self._task_buf_order_cache[cache_key], ) @@ -1856,13 +1654,11 @@ def trace_and_compile( has_coord_corr, ) - # --- Detect per-task buffers to promote as FX placeholders --- - # These buffers differ across tasks in the same structure group (they are - # NOT shared by share_params) or may be replaced out-of-place after - # compilation. Passing them as explicit arguments makes the compiled - # graph reusable across all tasks in the group. - task_buf_names = _get_sezm_task_buf_names(self) - task_buf_vals_trace = _get_sezm_task_buf_vals(self, task_buf_names) + # Promote the per-task buffers (see ``get_task_buffer_names``) to + # explicit graph inputs so one compiled graph serves the whole task + # group regardless of their per-task values. + task_buf_names = get_task_buffer_names(self) + task_buf_vals_trace = get_task_buffer_values(self, task_buf_names) # Resolve module references once for the buffer-patching closures. _am_patch = self.atomic_model @@ -1887,12 +1683,12 @@ def _patch_task_bufs( saved: dict[str, torch.Tensor] = {} try: for name, val in zip(task_buf_names, vals): - if name.startswith(_AM_PREFIX): - actual = name[len(_AM_PREFIX) :] + if name.startswith(AM_PREFIX): + actual = name[len(AM_PREFIX) :] saved[name] = _am_patch._buffers[actual] _am_patch._buffers[actual] = val - elif name.startswith(_FIT_PREFIX): - actual = name[len(_FIT_PREFIX) :] + elif name.startswith(FIT_PREFIX): + actual = name[len(FIT_PREFIX) :] saved[name] = _fitting_patch._buffers[actual] _fitting_patch._buffers[actual] = val except Exception: @@ -1905,11 +1701,11 @@ def _restore_task_bufs( ) -> None: """Restore original task-local buffers after tracing.""" for name, orig in saved.items(): - if name.startswith(_AM_PREFIX): - actual = name[len(_AM_PREFIX) :] + if name.startswith(AM_PREFIX): + actual = name[len(AM_PREFIX) :] _am_patch._buffers[actual] = orig - elif name.startswith(_FIT_PREFIX): - actual = name[len(_FIT_PREFIX) :] + elif name.startswith(FIT_PREFIX): + actual = name[len(FIT_PREFIX) :] _fitting_patch._buffers[actual] = orig need_coord_grad = self.do_grad_r() or self.do_grad_c() @@ -1999,29 +1795,12 @@ def compute_fn( # type: ignore[misc] finally: _restore_task_bufs(_saved) - # NOTE: Choose trace shapes that are pairwise-distinct primes >= 5. - # - # ``make_fx(tracing_mode="symbolic")`` introduces a sympy symbol per - # input dim. Two failure modes follow if those dims accidentally - # match each other or hit a PyTorch-internal "special" value: - # - # * Duck-shape unification: two input dims that share a concrete - # value at trace time get the SAME sympy symbol, baking an - # equality (``nloc == ntypes``, ``nloc == nall``, ...) the - # compiled graph will violate on later batches. - # * Size specialization: dims equal to ``1`` are baked as literal - # ``1`` regardless of duck-shape; values ``2``/``3``/``9`` are - # commonly literals inside the model (charge/spin width, - # Cartesian, virial) and may be unified with input symbols by - # ShapeEnv even with duck-shape off. - # - # Picking pairwise-distinct primes ``>= 5`` for ``nf``, ``nall``, - # ``nloc`` rules out both failure modes in one stroke: no two - # symbols can fuse (distinct values), and no symbol can hit a - # special literal (``5+`` primes skip ``1``/``2``/``3``/``9``). - # ``nsel``, ``dim_fparam``, ``dim_aparam`` and ``dim_chg_spin`` are - # contractually fixed by the model and added to the forbidden set - # so the chosen primes never collide with them either. + # Trace dims are pairwise-distinct primes >= 5 so ``make_fx`` neither + # unifies two axes onto one symbol (duck-shape) nor specializes an axis + # on a literal; ``next_safe_prime`` documents why. The forbidden set + # adds the model-contracted dims (``nsel``, fparam / aparam widths, + # charge_spin) and the promoted task-buffer dims so the chosen primes + # never collide with them. _forbidden: set[int] = {1, 2, 3, 9} for _tbv in task_buf_vals_trace: for _d in _tbv.shape: @@ -2044,24 +1823,21 @@ def compute_fn( # type: ignore[misc] # ``trace_nloc > trace_nall`` the slice silently truncates at # trace time, breaking the captured symbolic shape relation # ``atype.shape[1] == nloc``. - trace_nf = _next_safe_prime(5, _forbidden) + trace_nf = next_safe_prime(5, _forbidden) _forbidden.add(trace_nf) - trace_nloc = _next_safe_prime(trace_nf + 1, _forbidden) + trace_nloc = next_safe_prime(trace_nf + 1, _forbidden) _forbidden.add(trace_nloc) - trace_nall = _next_safe_prime(trace_nloc + 1, _forbidden) - - # Build trace inputs by padding/trimming real-data tensors into - # the chosen prime shapes. ``_trace_pad_dim`` duplicates the - # last slice when padding so index-bearing tensors (``nlist`` - # neighbor indices, ``mapping`` extended-to-local indices) keep - # valid values -- the duplicated row references the same atoms - # the previous row referenced. - coord_for_trace = _trace_pad_dim(extended_coord[:1], 0, trace_nf) - coord_for_trace = _trace_pad_dim(coord_for_trace, 1, trace_nall) - atype_for_trace = _trace_pad_dim(extended_atype[:1], 0, trace_nf) - atype_for_trace = _trace_pad_dim(atype_for_trace, 1, trace_nall) - nlist_for_trace = _trace_pad_dim(nlist[:1], 0, trace_nf) - nlist_for_trace = _trace_pad_dim(nlist_for_trace, 1, trace_nloc) + trace_nall = next_safe_prime(trace_nloc + 1, _forbidden) + + # Build trace inputs by padding/trimming real-data tensors into the + # chosen prime shapes; ``trace_pad_dim`` documents how index-bearing + # tensors keep valid values. + coord_for_trace = trace_pad_dim(extended_coord[:1], 0, trace_nf) + coord_for_trace = trace_pad_dim(coord_for_trace, 1, trace_nall) + atype_for_trace = trace_pad_dim(extended_atype[:1], 0, trace_nf) + atype_for_trace = trace_pad_dim(atype_for_trace, 1, trace_nall) + nlist_for_trace = trace_pad_dim(nlist[:1], 0, trace_nf) + nlist_for_trace = trace_pad_dim(nlist_for_trace, 1, trace_nloc) # Real nlist values are in ``[-1, real_nall)`` (``-1`` marks # padded slots, non-negative entries index into extended_coord). # After trimming ``nall`` down to ``trace_nall`` some of those @@ -2071,17 +1847,17 @@ def compute_fn( # type: ignore[misc] # ``trace_nall - 1`` (the ``-1`` padding stays untouched since # clamp only caps the high side). nlist_for_trace = torch.clamp(nlist_for_trace, max=trace_nall - 1) - mapping_for_trace = _trace_pad_dim(mapping[:1], 0, trace_nf) - mapping_for_trace = _trace_pad_dim(mapping_for_trace, 1, trace_nall) + mapping_for_trace = trace_pad_dim(mapping[:1], 0, trace_nf) + mapping_for_trace = trace_pad_dim(mapping_for_trace, 1, trace_nall) # Real mapping values are in ``[0, real_nloc)``. If # ``trace_nloc < real_nloc`` they can exceed ``trace_nloc`` and # silently propagate into ``src_local`` (used as a local-atom # index downstream). Clamp to ``trace_nloc - 1``. mapping_for_trace = torch.clamp(mapping_for_trace, min=0, max=trace_nloc - 1) - fp_for_trace = _trace_pad_dim(fp[:1], 0, trace_nf) - ap_for_trace = _trace_pad_dim(ap[:1], 0, trace_nf) - ap_for_trace = _trace_pad_dim(ap_for_trace, 1, trace_nloc) - charge_spin_for_trace = _trace_pad_dim(charge_spin[:1], 0, trace_nf) + fp_for_trace = trace_pad_dim(fp[:1], 0, trace_nf) + ap_for_trace = trace_pad_dim(ap[:1], 0, trace_nf) + ap_for_trace = trace_pad_dim(ap_for_trace, 1, trace_nloc) + charge_spin_for_trace = trace_pad_dim(charge_spin[:1], 0, trace_nf) trace_args = [ coord_for_trace, @@ -2093,8 +1869,8 @@ def compute_fn( # type: ignore[misc] charge_spin_for_trace, ] if extended_coord_corr is not None: - corr_for_trace = _trace_pad_dim(extended_coord_corr[:1], 0, trace_nf) - corr_for_trace = _trace_pad_dim(corr_for_trace, 1, trace_nall) + corr_for_trace = trace_pad_dim(extended_coord_corr[:1], 0, trace_nf) + corr_for_trace = trace_pad_dim(corr_for_trace, 1, trace_nall) trace_args.append(corr_for_trace) # Append task-buffer values last so they map to the *task_buf_vals # varargs in compute_fn. Their shapes are static (they don't vary @@ -2130,70 +1906,23 @@ def compute_fn( # type: ignore[misc] )(*trace_args) if self.training: - # NOTE: Training is the only mode that needs FX graph repair. - # ``fit_output_to_model_output(create_graph=True)`` asks autograd to - # keep the force graph differentiable with respect to model - # parameters. During ``make_fx`` tracing, autograd represents saved - # forward activations through double-detach chains such as + # Only the training trace runs with ``create_graph=True``, so only + # it carries the autograd-inserted detach chains that the FX repair + # targets; ``strip_saved_tensor_detach`` and ``rebuild_graph_module`` + # document why each step is required. # - # fwd_op -> detach_A -> detach_B -> bwd_op - # - # These detaches are bookkeeping in eager autograd, but ordinary FX - # operators after tracing. If left in place, they cut the - # second-derivative path from force loss back to theta and training - # silently produces zero updates for that term. Therefore the - # training graph first removes only the autograd-inserted detach - # chains, preserving user-explicit detach nodes by graph topology. - _strip_saved_tensor_detach(traced) - - # ``_strip_saved_tensor_detach`` mutates ``traced.graph`` via - # ``Graph.erase_node``. On some PyTorch builds (observed on - # 2.11+cu130), node erasure may leave stale C-level prev/next - # pointers on neighbouring FX nodes; Dynamo can later dereference - # those stale links while re-tracing the GraphModule and segfault. - # Rebuilding copies the graph into a fresh linked list after all - # training-only erasures are complete. - # - # Eval/inference must not take this repair path. In eval, - # ``create_graph=False`` means autograd does not insert the - # double-detach chains, so no nodes are erased. The eval graph also - # contains data-dependent ``nonzero`` output sizes from sparse edge - # compaction; copying that graph can make the resulting unbacked - # symbols fail Dynamo's shape-guard generation. Keeping the original - # eval GraphModule preserves the traced metadata that Inductor needs. - traced = _rebuild_graph_module(traced) - - # NOTE: Conservative Inductor options keep SeZM's dynamic edge - # graph from forming overly large Triton reduction kernels - # (``make_ttgir`` / ``PassManager::run failed``) on some - # GPU/Triton combinations. - compile_options: dict[str, Any] = { - "max_autotune": False, - "shape_padding": True, - "epilogue_fusion": False, - "triton.cudagraphs": False, - "max_fusion_size": 8, - "triton.persistent_reductions": False, - # NOTE: ``mix_order_reduction`` hits multiple bugs under - # data-dependent symbolic shapes on PyTorch <=2.11 - # (pytorch/pytorch#174379, #178080, #179494) -- our edge - # count is exactly that kind of shape. - "triton.mix_order_reduction": False, - } - try: - from torch._inductor import config as inductor_config - - valid_options = inductor_config.get_config_copy() - compile_options = { - key: value - for key, value in compile_options.items() - if key.replace("-", "_") in valid_options - } - except Exception: - # Older/future PyTorch builds may not expose the config registry. - # In that case keep the curated option set and let torch.compile - # surface any real backend error. - pass + # Eval/inference must not take this path: ``create_graph=False`` + # inserts no detach chains to strip, and the eval graph carries + # data-dependent ``nonzero`` sizes from sparse edge compaction whose + # unbacked symbols fail Dynamo's shape-guard generation once the + # graph is copied. The original eval GraphModule is kept so Inductor + # sees the traced metadata it needs. + strip_saved_tensor_detach(traced) + traced = rebuild_graph_module(traced) + + # The conservative Inductor option set that keeps the dynamic edge + # graph lowerable is centralised in ``deepmd.pt.utils.compile_compat``. + compile_options = build_inductor_compile_options() # NOTE: Store the compiled callable inside the plain-``dict`` # cache ``compiled_core_compute_cache``. The dict itself was installed diff --git a/deepmd/pt/train/training.py b/deepmd/pt/train/training.py index 48e550729f..712ffa3f44 100644 --- a/deepmd/pt/train/training.py +++ b/deepmd/pt/train/training.py @@ -3,7 +3,6 @@ import functools import json import logging -import os import time from collections.abc import ( Callable, @@ -11,7 +10,6 @@ Iterable, ) from contextlib import ( - contextmanager, nullcontext, ) from copy import ( @@ -76,6 +74,7 @@ ) from deepmd.pt.train.utils import ( clip_grad_norm_with_stable_fallback, + scoped_env_defaults, ) from deepmd.pt.train.validation import ( FullValidator, @@ -148,22 +147,6 @@ log = logging.getLogger(__name__) -@contextmanager -def _scoped_env_defaults(defaults: dict[str, str]) -> Generator[None, None, None]: - """Temporarily set missing environment variables and restore them afterward.""" - previous = {key: os.environ.get(key) for key in defaults} - try: - for key, value in defaults.items(): - os.environ.setdefault(key, value) - yield - finally: - for key, value in previous.items(): - if value is None: - os.environ.pop(key, None) - else: - os.environ[key] = value - - class Trainer: def __init__( self, @@ -462,7 +445,7 @@ def get_lr(lr_params: dict[str, Any]) -> BaseLR: # Model # SeZMModel samples these eval/inference env vars exactly once inside # __init__; keep config-derived defaults scoped to construction. - with _scoped_env_defaults(infer_env_defaults): + with scoped_env_defaults(infer_env_defaults): self.model = get_model_for_wrapper( model_params, resuming=resuming, diff --git a/deepmd/pt/train/utils.py b/deepmd/pt/train/utils.py index 2cbf536ac2..c4074ce514 100644 --- a/deepmd/pt/train/utils.py +++ b/deepmd/pt/train/utils.py @@ -6,6 +6,10 @@ ) import math +import os +from contextlib import ( + contextmanager, +) from typing import ( TYPE_CHECKING, ) @@ -15,6 +19,7 @@ if TYPE_CHECKING: from collections.abc import ( Callable, + Generator, Iterable, ) @@ -148,3 +153,19 @@ def stable_clip_grad_norm( param.grad.detach().mul_(clip_coef) return torch.tensor(total_norm, dtype=torch.float64, device=first_device) + + +@contextmanager +def scoped_env_defaults(defaults: dict[str, str]) -> Generator[None, None, None]: + """Temporarily set missing environment variables and restore them afterward.""" + previous = {key: os.environ.get(key) for key in defaults} + try: + for key, value in defaults.items(): + os.environ.setdefault(key, value) + yield + finally: + for key, value in previous.items(): + if value is None: + os.environ.pop(key, None) + else: + os.environ[key] = value diff --git a/deepmd/pt/utils/compile_compat.py b/deepmd/pt/utils/compile_compat.py new file mode 100644 index 0000000000..76c6c0c046 --- /dev/null +++ b/deepmd/pt/utils/compile_compat.py @@ -0,0 +1,406 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""PyTorch ``torch.compile`` support for the deepmd backend. + +The deepmd PyTorch backend traces selected compute functions with ``make_fx`` +and lowers them through Inductor while preserving the second-order autograd +graph that force training requires. This module gathers the helpers that +support that pipeline together with the PyTorch defect workarounds it needs, so +the model code stays free of compiler plumbing. + +The contents fall into two kinds: + +* helpers and workarounds common to the supported releases -- trace-shape and + trace-input preparation, per-task buffer promotion, FX graph repair, the + Inductor option lockdown, and the process-global configuration; and +* a workaround specific to PyTorch 2.12, which must not be applied on 2.11. + +Only PyTorch 2.11.x and 2.12.x are permitted for compilation (see +:func:`check_compile_torch_version`). +""" + +from __future__ import ( + annotations, +) + +import os +from typing import ( + Any, +) + +import torch +from packaging.version import ( + Version, +) + +__all__ = [ + "AM_PREFIX", + "FIT_PREFIX", + "apply_global_compile_patches", + "build_inductor_compile_options", + "check_compile_torch_version", + "get_task_buffer_names", + "get_task_buffer_values", + "is_prime", + "next_safe_prime", + "patch_inductor_symbolic_divisibility", + "rebuild_graph_module", + "strip_saved_tensor_detach", + "trace_pad_dim", +] + + +# ============================================================================= +# Common workarounds (PyTorch 2.11 and 2.12) +# ============================================================================= +def apply_global_compile_patches() -> None: + """Apply every process-global PyTorch adjustment the compile path needs. + + The adjustments are mutually independent and individually idempotent. The + function is intended to run exactly once, when the model module is + imported, so that the global state is established before the first + compilation. The symbolic-divisibility repair is applied only on PyTorch + 2.12, where the regression exists. + """ + # Silence Inductor / Triton autotune console dumps. ``torch.compile`` + # reads these environment variables once, when its backend is first + # initialised, so they must be set before the first compilation; setting + # them afterwards has no effect in the current run. ``setdefault`` + # preserves any explicit user-level override. + os.environ.setdefault("TORCHINDUCTOR_MAX_AUTOTUNE_REPORT_CHOICES_STATS", "0") + os.environ.setdefault("TRITON_PRINT_AUTOTUNING", "0") + + # Disable DDPOptimizer graph splitting globally. The inner + # ``torch.compile`` calls sit *inside* a DDP-wrapped model; DDPOptimizer + # assumes it sees the *whole* model and splits the FX graph at DDP bucket + # boundaries. For an inner submodule that heuristic produces subgraphs + # whose outputs include symbolic integers, which then crash aot_autograd + # with ``'int' object has no attribute 'meta'``. + # See https://github.com/pytorch/pytorch/issues/134182. Turning the + # optimizer off globally is safe because the compile region always owns its + # own boundary and the surrounding DDP wrapper operates on the full model + # call. + import torch._dynamo.config as dynamo_config + + dynamo_config.optimize_ddp = False + + # The symbolic-divisibility regression exists only on PyTorch 2.12; the + # 2.11 backend evaluates the same predicate correctly and must not be + # patched. + if Version(torch.__version__).release[:2] == (2, 12): + patch_inductor_symbolic_divisibility() + + +def check_compile_torch_version() -> None: + """Fail fast when ``torch.compile`` is requested on an unsupported PyTorch.""" + version = Version(torch.__version__).release + if len(version) < 2 or (version[:2] != (2, 11) and version[:2] != (2, 12)): + raise RuntimeError( + "deepmd `torch.compile` support requires PyTorch 2.11.x or 2.12.x; " + f"found torch {torch.__version__}." + ) + + +def is_prime(n: int) -> bool: + """Return True when ``n`` is a prime integer (``n >= 2``).""" + if n < 2: + return False + if n < 4: + return True + if n % 2 == 0: + return False + k = 3 + while k * k <= n: + if n % k == 0: + return False + k += 2 + return True + + +def next_safe_prime(start: int, forbidden: set[int]) -> int: + """Return the smallest prime ``>= max(start, 5)`` not in ``forbidden``. + + Used by the ``make_fx`` symbolic-tracing path to choose collision-free + trace-time sizes for ``nf``, ``nall`` and ``nloc``. Primes ``>= 5`` + avoid every dim PyTorch specializes on (``1`` → broadcasting, + ``2``/``3``/``9`` → Cartesian / virial / charge_spin literals baked + into model code) and guarantee distinct values, which suppresses + make_fx's duck-shape unification without needing the + ``ShapeEnv(duck_shape=False)`` patch. + """ + n = max(start, 5) + while not is_prime(n) or n in forbidden: + n += 1 + return n + + +def trace_pad_dim(t: torch.Tensor, dim: int, target: int) -> torch.Tensor: + """Pad or trim ``t`` along ``dim`` so ``t.shape[dim] == target``. + + Padding duplicates the last slice along ``dim``; trimming drops + trailing slices. Used to coerce real-data trace inputs into the + prime-numbered shapes chosen by :func:`next_safe_prime`. + + Duplicating the last slice preserves valid index values inside + index-bearing tensors (``nlist`` neighbor indices, ``mapping`` + extended-to-local indices) because the duplicated row reuses the + previously-valid row's values. Trimming likewise never invalidates + indices. Only shapes flow downstream during ``make_fx`` tracing, + so the exact replicated/trimmed values do not affect the FX graph. + """ + cur = int(t.shape[dim]) + if cur == target: + return t + if cur > target: + sl: list[slice] = [slice(None)] * t.ndim + sl[dim] = slice(None, target) + return t[tuple(sl)] + sl = [slice(None)] * t.ndim + sl[dim] = slice(-1, None) + last = t[tuple(sl)] + repeats = target - cur + return torch.cat([t, *([last] * repeats)], dim=dim) + + +def strip_saved_tensor_detach(gm: torch.fx.GraphModule) -> None: + """Strip ``aten.detach`` nodes that ``make_fx`` inserts for saved tensors. + + When ``make_fx`` decomposes ``autograd.grad(..., create_graph=True)``, + the autograd engine wraps every saved forward activation in a double-detach + chain (e.g. ``tanh -> detach_A -> detach_B -> tanh_backward``). These + detach nodes block the second-order gradient path from the loss back to + model parameters, causing incorrect parameter updates during force-loss + training. + + User-explicit ``.detach()`` calls are preserved. The two categories are + distinguished by graph topology alone — no hard-coded op names — using + three rules: + + * *Chain inner*: input is another detach node. + * *Dead node*: no downstream users. + * *Chain head*: *all* users are detach nodes. + + Any detach that does **not** match these rules is treated as user-explicit + and left untouched. + """ + _DETACH = torch.ops.aten.detach.default + + def _is_detach(n: torch.fx.Node) -> bool: + return n.op == "call_function" and n.target == _DETACH + + # Pass 1 classifies every detach against the original graph. Erasing + # nodes eagerly would let later classifications inspect a mutated + # neighbourhood and misjudge the chain-interior / chain-head / dead + # boundaries; the double-detach pattern in particular changes category + # within a single erase. Classifying first and mutating second keeps the + # topology rules well defined. + to_remove: list[torch.fx.Node] = [] + for node in gm.graph.nodes: + if not _is_detach(node): + continue + input_node = node.args[0] + users = list(node.users.keys()) + is_chain_inner = _is_detach(input_node) + is_dead = len(users) == 0 + is_chain_head = len(users) > 0 and all(_is_detach(u) for u in users) + if is_chain_inner or is_dead or is_chain_head: + to_remove.append(node) + + # Pass 2 rewires and erases after classification is complete. + # ``replace_all_uses_with`` forwards every consumer to the detach's input + # and ``erase_node`` removes the now-dead detach, so the graph never holds + # a partially redirected state. + for node in to_remove: + node.replace_all_uses_with(node.args[0]) + gm.graph.erase_node(node) + + gm.graph.lint() + gm.recompile() + + +def rebuild_graph_module(gm: torch.fx.GraphModule) -> torch.fx.GraphModule: + """Return a fresh ``GraphModule`` whose node linked-list is newly allocated. + + After ``strip_saved_tensor_detach`` erases nodes via + ``Graph.erase_node()``, the internal doubly-linked list may retain + stale pointers to erased nodes. When ``torch.compile`` later + triggers dynamo re-tracing and iterates ``graph.nodes`` to read + ``nd.meta`` (``output_graph.py:_create_proxy``), accessing these + stale entries causes a segfault. + + Copying every node into a brand-new ``Graph`` builds a clean linked + list from scratch, side-stepping the corruption entirely. + """ + old_graph = gm.graph + new_graph = torch.fx.Graph() + # node_copy needs a mapper from old nodes to their copies in new_graph. + val_map: dict[torch.fx.Node, torch.fx.Node] = {} + for node in old_graph.nodes: + val_map[node] = new_graph.node_copy(node, lambda n: val_map[n]) + new_graph.lint() + new_gm = torch.fx.GraphModule(gm, new_graph) + return new_gm + + +def build_inductor_compile_options() -> dict[str, Any]: + """Return the conservative Inductor options used to lower the dynamic graph. + + The option set disables every Inductor and Triton feature that has + misbehaved on the combination of data-dependent edge counts and a + second-order autograd graph -- most visibly the oversized fused Triton + reduction kernels that fail ``make_ttgir`` (``PassManager::run failed``) on + some GPU/Triton combinations. Options absent from the running PyTorch's + configuration registry are dropped so the returned dictionary stays valid + across releases. + """ + compile_options: dict[str, Any] = { + "max_autotune": False, + "shape_padding": True, + "epilogue_fusion": False, + "triton.cudagraphs": False, + "max_fusion_size": 8, + "triton.persistent_reductions": False, + # ``mix_order_reduction`` is defective under data-dependent symbolic + # shapes on PyTorch 2.11 and earlier (pytorch/pytorch#174379, #178080, + # #179494); the edge count is exactly that kind of shape. + "triton.mix_order_reduction": False, + } + try: + from torch._inductor import config as inductor_config + + valid_options = inductor_config.get_config_copy() + compile_options = { + key: value + for key, value in compile_options.items() + if key.replace("-", "_") in valid_options + } + except Exception: + # Older/future PyTorch builds may not expose the config registry. + # In that case keep the curated option set and let torch.compile + # surface any real backend error. + pass + return compile_options + + +# Prefix namespace for promoted per-task buffer names. +AM_PREFIX = "am/" +FIT_PREFIX = "fit/" + + +def get_task_buffer_names(model: Any) -> tuple[str, ...]: + """Return the ordered names of per-task buffers to promote as FX placeholders. + + ``model`` is any deepmd model exposing ``atomic_model`` and + ``atomic_model.fitting_net``. Promoting these buffers as explicit graph + inputs lets one compiled graph stay correct across tasks that differ only + in their values. Always promotes: + + * ``out_bias``, ``out_std`` on ``atomic_model`` -- may be replaced + out-of-place by ``model_change_out_bias``, so the compiled graph must + never bake them as constants. + * ``bias_atom_e`` on the fitting net -- task-specific per-type bias that + differs across tasks after ``share_params``. + * ``case_embd`` on the fitting net -- task-identity vector used for + multi-task case conditioning. + """ + names: list[str] = [] + atomic_model = model.atomic_model + fitting = atomic_model.fitting_net + for bname in ("out_bias", "out_std"): + if atomic_model._buffers.get(bname) is not None: + names.append(AM_PREFIX + bname) + for bname in ("bias_atom_e", "case_embd"): + if fitting._buffers.get(bname) is not None: + names.append(FIT_PREFIX + bname) + return tuple(names) + + +def get_task_buffer_values( + model: Any, + names: tuple[str, ...], +) -> tuple[torch.Tensor, ...]: + """Return the current tensor values for the given promoted-buffer names.""" + if not names: + return () + atomic_model = model.atomic_model + fitting = atomic_model.fitting_net + vals: list[torch.Tensor] = [] + for name in names: + if name.startswith(AM_PREFIX): + vals.append(atomic_model._buffers[name[len(AM_PREFIX) :]]) + elif name.startswith(FIT_PREFIX): + vals.append(fitting._buffers[name[len(FIT_PREFIX) :]]) + else: + raise ValueError(f"Unknown task-buffer name: {name}") + return tuple(vals) + + +# ============================================================================= +# PyTorch 2.12-specific workarounds +# ============================================================================= +def patch_inductor_symbolic_divisibility() -> None: + """Repair the PyTorch 2.12 Inductor symbolic-divisibility regression. + + ``SizeVarAllocator.statically_known_multiple_of`` determines whether one + symbolic size is an exact multiple of another. ``SIMDKernel`` consults it + while splitting a fused iteration space into kernel groups and raises + ``CantSplit`` whenever the test reports a non-multiple. + + PyTorch 2.11 evaluated the test with sympy's native modulo operator, which + factors polynomials, so an expression such as ``(32*s + 64) % (s + 2)`` + reduces to ``0`` and the split proceeds. PyTorch 2.12 rewrote the helper + and, for symbolic denominators, routes the test through Inductor's own + ``Mod`` implementation, which does not factor. ``Mod(32*s + 64, s + 2)`` + therefore stays unevaluated, the test returns ``False``, and lowering + aborts with:: + + CantSplit: 32*s38 + 64 not divisible by s38 + 2 + + Such ``c * (s + k)`` over ``(s + k)`` patterns arise whenever a padded axis + is multiplied by a constant channel count, which is common in the compiled + descriptor graph. + + The wrapper re-tests a symbolic denominator that the original rejects, this + time with sympy's simplifying modulo. It reports a multiple only when sympy + proves the remainder is identically zero, so it never asserts an unsound + divisibility, and it leaves the 2.11 behaviour unchanged because the + original test already succeeds there. A sentinel attribute on the class + ensures the patch is installed at most once. + """ + try: + import sympy + from torch._inductor.sizevars import ( + SizeVarAllocator, + ) + except Exception: + return + + if getattr(SizeVarAllocator, "_dp_divisibility_patched", False): + return + + original_known_multiple_of = SizeVarAllocator.statically_known_multiple_of + + def statically_known_multiple_of( + self: Any, numerator: Any, denominator: Any + ) -> bool: + if original_known_multiple_of(self, numerator, denominator): + return True + # Integer denominators use the structural divisibility path introduced + # in 2.12, which is unaffected by the regression and needs no retry. + if isinstance(denominator, (int, sympy.Integer)): + return False + try: + num = sympy.sympify(numerator) + den = sympy.sympify(denominator) + # The bound mirrors Inductor's own guard against the cost of + # symbolic evaluation on expressions with many free symbols. + if len(num.free_symbols) > 20: + return False + # sympy's modulo factors the numerator, so (32*s + 64) % (s + 2) + # reduces to 0 and the divisibility is proven. + return bool(self.statically_known_true(sympy.Eq(num % den, 0))) + except Exception: + return False + + statically_known_multiple_of.__doc__ = original_known_multiple_of.__doc__ + SizeVarAllocator.statically_known_multiple_of = statically_known_multiple_of + SizeVarAllocator._dp_divisibility_patched = True diff --git a/source/tests/pt/model/test_sezm_model.py b/source/tests/pt/model/test_sezm_model.py index 5abbf55173..0396b24ad2 100644 --- a/source/tests/pt/model/test_sezm_model.py +++ b/source/tests/pt/model/test_sezm_model.py @@ -57,15 +57,17 @@ module=r"torch\._functorch\._aot_autograd\.autograd_cache", ) -# TODO(torch-2.11): SeZM's ``torch.compile`` / AOT-export code paths are only -# stable on torch 2.11.x. CI currently pins torch 2.10, where the compiled path -# can segfault or drift, and other torch versions are similarly unstable. Skip -# the compile-parity tests off 2.11 until CI standardizes on a SeZM-compatible -# torch, then drop this guard. +# SeZM's ``torch.compile`` / AOT-export code paths are validated on torch +# 2.11.x and 2.12.x, the releases the compile pipeline supports (see +# ``deepmd.pt.utils.compile_compat``). Other torch versions can segfault or +# drift, so the compile-parity tests are skipped there. _TORCH_VERSION = parse_version(torch.__version__) -_SKIP_OFF_TORCH_211 = (_TORCH_VERSION.major, _TORCH_VERSION.minor) != (2, 11) -_SKIP_OFF_TORCH_211_REASON = ( - "SeZM's torch.compile path is only stable on torch 2.11.x; " +_SKIP_OFF_COMPILE_TORCH = (_TORCH_VERSION.major, _TORCH_VERSION.minor) not in { + (2, 11), + (2, 12), +} +_SKIP_OFF_COMPILE_TORCH_REASON = ( + "SeZM's torch.compile path is only supported on torch 2.11.x and 2.12.x; " f"current torch is {torch.__version__}." ) @@ -386,7 +388,7 @@ def _train_steps( name: param.detach().clone() for name, param in model.named_parameters() } - @unittest.skipIf(_SKIP_OFF_TORCH_211, _SKIP_OFF_TORCH_211_REASON) + @unittest.skipIf(_SKIP_OFF_COMPILE_TORCH, _SKIP_OFF_COMPILE_TORCH_REASON) def test_compile_cache_slots_and_eval_shape_change(self) -> None: """Compile cache slots should coexist while eval handles batch-size growth.""" coord_1, atype_1, box_1, _, _, _ = self._make_tiny_frame() @@ -483,7 +485,7 @@ def test_compile_cache_slots_and_eval_shape_change(self) -> None: model_cmp.compiled_core_compute_cache[eval_key], callable_eval_first ) - @unittest.skipIf(_SKIP_OFF_TORCH_211, _SKIP_OFF_TORCH_211_REASON) + @unittest.skipIf(_SKIP_OFF_COMPILE_TORCH, _SKIP_OFF_COMPILE_TORCH_REASON) def test_charge_spin_condition_matches_compile(self) -> None: """Charge/spin conditions should work through the compiled energy path.""" coord, atype, box, _, _, _ = self._make_tiny_frame() @@ -625,7 +627,7 @@ def test_fixed_edge_geometry_matches_standard_cache(self) -> None: torch.testing.assert_close(cache_std.D_full, cache_sparse.D_full[:n_real]) torch.testing.assert_close(cache_std.Dt_full, cache_sparse.Dt_full[:n_real]) - @unittest.skipIf(_SKIP_OFF_TORCH_211, _SKIP_OFF_TORCH_211_REASON) + @unittest.skipIf(_SKIP_OFF_COMPILE_TORCH, _SKIP_OFF_COMPILE_TORCH_REASON) def test_eval_compile_policy(self) -> None: """Eval should stay eager by default and compile only with env override.""" model = get_sezm_model(self._build_model_params(use_compile=True)) @@ -642,7 +644,7 @@ def test_eval_compile_policy(self) -> None: model_eval.eval() self.assertTrue(model_eval.should_use_compile()) - @unittest.skipIf(_SKIP_OFF_TORCH_211, _SKIP_OFF_TORCH_211_REASON) + @unittest.skipIf(_SKIP_OFF_COMPILE_TORCH, _SKIP_OFF_COMPILE_TORCH_REASON) def test_forward_backward_double_backward_matches_compile(self) -> None: """ Check forward, backward, double backward, and short training consistency. @@ -927,7 +929,7 @@ def _build_wrapper(use_compile: bool) -> ModelWrapper: torch.allclose(out_e1["energy"], out_e2["energy"], atol=1.0e-8) ) - @unittest.skipIf(_SKIP_OFF_TORCH_211, _SKIP_OFF_TORCH_211_REASON) + @unittest.skipIf(_SKIP_OFF_COMPILE_TORCH, _SKIP_OFF_COMPILE_TORCH_REASON) def test_multitask_compile_matches_eager(self) -> None: """Legacy case embedding concatenation should match through compile.""" self._assert_multitask_compile_matches_eager(case_film_embd=False) @@ -1901,7 +1903,7 @@ def _build_matched_lora_models() -> tuple[SeZMModel, SeZMModel]: model_compile.load_state_dict(model_eager.state_dict()) return model_eager, model_compile - @unittest.skipIf(_SKIP_OFF_TORCH_211, _SKIP_OFF_TORCH_211_REASON) + @unittest.skipIf(_SKIP_OFF_COMPILE_TORCH, _SKIP_OFF_COMPILE_TORCH_REASON) def test_forward_and_backward_match_eager(self) -> None: """Forward / first-order / second-order outputs agree with eager.""" coord, atype, box = self._tiny_system() From f01dbfae5c5647594898033544c0ef332f99a713 Mon Sep 17 00:00:00 2001 From: OutisLi Date: Mon, 8 Jun 2026 20:53:16 +0800 Subject: [PATCH 07/18] fix --- .../pt/model/descriptor/sezm_nn/grid_net.py | 7 ++- deepmd/pt/model/model/sezm_model.py | 27 +++++---- deepmd/pt/utils/nv_nlist.py | 58 +++++++++++++------ .../test_descriptor_sezm_grid_projection.py | 28 +++++++++ source/tests/pt/model/test_nlist_backend.py | 12 +++- source/tests/pt/model/test_nv_nlist.py | 39 ++++++++++--- 6 files changed, 130 insertions(+), 41 deletions(-) diff --git a/deepmd/pt/model/descriptor/sezm_nn/grid_net.py b/deepmd/pt/model/descriptor/sezm_nn/grid_net.py index 89d79d13d5..867dd47782 100644 --- a/deepmd/pt/model/descriptor/sezm_nn/grid_net.py +++ b/deepmd/pt/model/descriptor/sezm_nn/grid_net.py @@ -40,6 +40,7 @@ SwiGLU, ) from .indexing import ( + build_l_major_index, build_m_major_l_index, map_degree_idx, ) @@ -69,7 +70,11 @@ def _build_frame_degree_index( if coefficient_layout == "m_major": return build_m_major_l_index(lmax, mmax, device=env.DEVICE) if coefficient_layout == "packed": - return map_degree_idx(lmax, device=env.DEVICE) + degree_index = map_degree_idx(lmax, device=env.DEVICE) + if int(mmax) == int(lmax): + return degree_index + coeff_index = build_l_major_index(lmax, mmax, device=env.DEVICE) + return degree_index.index_select(0, coeff_index) raise ValueError("`coefficient_layout` must be either 'packed' or 'm_major'") diff --git a/deepmd/pt/model/model/sezm_model.py b/deepmd/pt/model/model/sezm_model.py index f29ba50f06..977cc7f1aa 100644 --- a/deepmd/pt/model/model/sezm_model.py +++ b/deepmd/pt/model/model/sezm_model.py @@ -516,9 +516,11 @@ SeZMModel_ = make_model(SeZMAtomicModel) -# Local-atom count above which the O(N) Toolkit-Ops cell list replaces the dense -# all-pairs builder for periodic CUDA systems. +# Local-atom counts above which Toolkit-Ops replaces the dense all-pairs builder. +# Non-periodic systems switch when dense all-pairs transients become memory-heavy, +# even though the dense path remains slightly faster at medium sizes. SEZM_NV_NLIST_THRESHOLD = 1024 +SEZM_NV_NONPERIODIC_NLIST_THRESHOLD = 2048 # Apply the process-global PyTorch workarounds the compile pipeline relies on # (autotune log suppression, DDP optimiser, and the 2.12 divisibility repair) @@ -2251,9 +2253,10 @@ def build_neighbor_list( Used when the model constructs its own neighbor list from ``coord`` / ``box``, as opposed to ``forward_lower`` which receives an externally built nlist (e.g. from LAMMPS or an inference ``NeighborList`` strategy). - Large periodic CUDA systems use the O(N) Toolkit-Ops cell list + Large CUDA systems use the Toolkit-Ops neighbor list (:class:`NvNeighborList`); all other cases use the dense all-pairs - builder. Either way the neighbor list is trimmed to ``sum(sel)``. + builder. The non-periodic Toolkit-Ops path uses a larger threshold + because the dense builder is still faster at small sizes. Parameters ---------- @@ -2270,14 +2273,14 @@ def build_neighbor_list( Extended coordinates, extended atom types, mapping, and neighbor list. """ nloc = atype.shape[1] - if ( - box is not None - and coord.is_cuda - and nloc >= SEZM_NV_NLIST_THRESHOLD - and is_nv_available() - ): - # Large periodic systems: the device-resident O(N) Toolkit-Ops cell - # list avoids the dense all-pairs ghost expansion. It already keeps + nv_threshold = ( + SEZM_NV_NLIST_THRESHOLD + if box is not None + else SEZM_NV_NONPERIODIC_NLIST_THRESHOLD + ) + if coord.is_cuda and nloc >= nv_threshold and is_nv_available(): + # Large systems: the device-resident Toolkit-Ops neighbor list avoids + # the dense all-pairs ghost expansion. It already keeps # the nearest sum(sel) neighbors (fixed width, like the standard # builder); only its (nlist, mapping) order is swapped to this # method's (mapping, nlist) contract. diff --git a/deepmd/pt/utils/nv_nlist.py b/deepmd/pt/utils/nv_nlist.py index 460fcadbb1..74dd70a9c1 100644 --- a/deepmd/pt/utils/nv_nlist.py +++ b/deepmd/pt/utils/nv_nlist.py @@ -3,8 +3,7 @@ A :class:`~deepmd.dpmodel.utils.neighbor_list.NeighborList` implementation that builds the extended representation ``(extended_coord, extended_atype, nlist, -mapping)`` using the device-resident O(N) cell list in ``nvalchemiops``, intended -for large periodic systems. +mapping)`` using the device-resident neighbor-list kernels in ``nvalchemiops``. Toolkit-Ops returns a dense ``[total_atoms, max_neighbors]`` neighbor matrix over the flattened batch. The matrix is converted to the DeePMD extended-atom contract @@ -19,6 +18,7 @@ annotations, ) +import logging from typing import ( Any, ) @@ -33,18 +33,22 @@ ) NV_CELL_LIST_THRESHOLD = 1024 +NV_NONPERIODIC_CELL_LIST_THRESHOLD = 4096 + +log = logging.getLogger(__name__) def is_nv_available() -> bool: """Whether the ``nvalchemiops`` Toolkit-Ops neighbor list is importable.""" try: import nvalchemiops.torch.neighbors # noqa: F401 - except ImportError: + except (ImportError, OSError, RuntimeError) as err: + log.debug("nvalchemiops Toolkit-Ops neighbor list is unavailable: %s", err) return False return True -def choose_nv_nlist_method(nloc: int) -> str: +def choose_nv_nlist_method(nloc: int, *, periodic: bool = True) -> str: """Choose the Toolkit-Ops neighbor method for a homogeneous batch. Parameters @@ -57,18 +61,21 @@ def choose_nv_nlist_method(nloc: int) -> str: str Toolkit-Ops method name. """ - if nloc >= NV_CELL_LIST_THRESHOLD: + threshold = ( + NV_CELL_LIST_THRESHOLD if periodic else NV_NONPERIODIC_CELL_LIST_THRESHOLD + ) + if nloc >= threshold: return "batch_cell_list" return "batch_naive" class NvNeighborList(NeighborList): - """O(N) neighbor-list strategy using the ``nvalchemiops`` cell list. + """Neighbor-list strategy using the ``nvalchemiops`` kernels. Implements the :class:`~deepmd.dpmodel.utils.neighbor_list.NeighborList` interface on torch tensors; the search runs on the device of the input - coordinates. A periodic ``box`` is required -- the cell list needs a cell to - wrap periodic images. + coordinates. Periodic inputs materialize shifted ghost atoms; non-periodic + inputs keep only local atoms. """ def build( @@ -82,35 +89,37 @@ def build( """Build the extended system and neighbor list. See :meth:`deepmd.dpmodel.utils.neighbor_list.NeighborList.build`. The - returned ``nlist`` is distance-sorted and truncated to ``sum(sel)``. A - periodic ``box`` is required, as the cell list operates on a periodic cell. + returned ``nlist`` is distance-sorted and truncated to ``sum(sel)``. """ from nvalchemiops.torch.neighbors import ( neighbor_list, ) - if box is None: - raise ValueError("NvNeighborList requires a periodic box; got box=None.") - nf, nloc = atype.shape[:2] device = coord.device target_neighbors = int(sum(sel)) search_capacity = target_neighbors total_atoms = nf * nloc - cell = box.reshape(nf, 3, 3).to(device=device, dtype=coord.dtype) - coord = normalize_coord(coord.reshape(nf, nloc, 3), cell) + coord = coord.reshape(nf, nloc, 3) + periodic = box is not None + if not periodic: + cell = None + pbc = None + else: + cell = box.reshape(nf, 3, 3).to(device=device, dtype=coord.dtype) + coord = normalize_coord(coord, cell) + pbc = torch.ones((nf, 3), dtype=torch.bool, device=device) positions_for_nlist = coord.reshape(total_atoms, 3).detach() - pbc = torch.ones((nf, 3), dtype=torch.bool, device=device) batch_idx = torch.arange( nf, dtype=torch.int32, device=device ).repeat_interleave(nloc) batch_ptr = torch.arange(nf + 1, dtype=torch.int32, device=device) * nloc - method = choose_nv_nlist_method(nloc) + method = choose_nv_nlist_method(nloc, periodic=periodic) # Grow the search capacity until all neighbors fit so the distance-sort # below selects the true nearest ``sum(sel)``. while True: - neighbor_matrix, num_neighbors, shifts = neighbor_list( + nlist_result = neighbor_list( positions_for_nlist, float(rcut), cell=cell, @@ -122,6 +131,15 @@ def build( return_neighbor_list=False, wrap_positions=False, ) + if len(nlist_result) == 2: + neighbor_matrix, num_neighbors = nlist_result + shifts = torch.zeros( + (*neighbor_matrix.shape, 3), + dtype=torch.int32, + device=device, + ) + else: + neighbor_matrix, num_neighbors, shifts = nlist_result max_found = ( int(num_neighbors.max().item()) if num_neighbors.numel() > 0 else 0 ) @@ -205,7 +223,7 @@ def _matrix_to_extended_inputs( *, coord: torch.Tensor, atype: torch.Tensor, - cell: torch.Tensor, + cell: torch.Tensor | None, nloc: int, neighbor_matrix: torch.Tensor, num_neighbors: torch.Tensor, @@ -272,6 +290,8 @@ def _matrix_to_extended_inputs( shifted_edge_idx = torch.nonzero(~zero_shift, as_tuple=False).flatten() if shifted_edge_idx.numel() == 0: return coord, atype, local_mapping, nlist + if cell is None: + raise RuntimeError("Non-periodic Toolkit-Ops neighbor list returned shifts.") # === Step 3. Materialize each unique shifted atom once per frame === # A shifted source may appear in many center atoms' neighbor slots. Dedup by diff --git a/source/tests/pt/model/test_descriptor_sezm_grid_projection.py b/source/tests/pt/model/test_descriptor_sezm_grid_projection.py index 602d3d47a6..4d10fa7f42 100644 --- a/source/tests/pt/model/test_descriptor_sezm_grid_projection.py +++ b/source/tests/pt/model/test_descriptor_sezm_grid_projection.py @@ -873,6 +873,34 @@ def test_kmax_two_quadratic_grid_ops_are_equivariant(self) -> None: rtol=1e-12, ) + def test_packed_truncated_cross_grid_net_forward(self) -> None: + torch.manual_seed(8500) + net = SO3GridNet( + lmax=3, + mmax=1, + kmax=1, + channels=2, + n_focus=1, + mode="cross", + op_type="glu", + dtype=torch.float64, + layout="ndfc", + coefficient_layout="packed", + trainable=False, + ).to(self.device) + coeff_dim = net.projector.coeff_dim // net.n_frames + query = torch.randn( + 2, + coeff_dim, + 1, + net.context_channels, + dtype=torch.float64, + device=self.device, + ) + context = torch.randn_like(query) + out = net(query, context) + self.assertEqual(out.shape, (2, coeff_dim, 1, net.output_channels)) + class TestSO3CounterExample(unittest.TestCase): def setUp(self) -> None: diff --git a/source/tests/pt/model/test_nlist_backend.py b/source/tests/pt/model/test_nlist_backend.py index a09127aa31..d38262dacd 100644 --- a/source/tests/pt/model/test_nlist_backend.py +++ b/source/tests/pt/model/test_nlist_backend.py @@ -198,8 +198,18 @@ def test_strategy_matches_native_multiframe(pt_files, backend: str, name: str) - @_BACKEND_MARKS["vesin"] @pytest.mark.parametrize("name", list(ALL_MODELS)) def test_vesin_matches_native_nonperiodic(pt_files, name: str) -> None: - """Vesin also supports non-periodic systems (nv requires a periodic box).""" + """Vesin also supports non-periodic systems.""" coords, atype, _ = _system() dp_native = DeepPot(pt_files[name], nlist_backend="native") dp_vesin = DeepPot(pt_files[name], nlist_backend="vesin") _assert_eval_close(dp_native, dp_vesin, coords, None, atype, f"{name} vesin nopbc") + + +@_BACKEND_MARKS["nv"] +@pytest.mark.parametrize("name", list(ALL_MODELS)) +def test_nv_matches_native_nonperiodic(pt_files, name: str) -> None: + """NV also supports non-periodic systems.""" + coords, atype, _ = _system() + dp_native = DeepPot(pt_files[name], nlist_backend="native") + dp_nv = DeepPot(pt_files[name], nlist_backend="nv") + _assert_eval_close(dp_native, dp_nv, coords, None, atype, f"{name} nv nopbc") diff --git a/source/tests/pt/model/test_nv_nlist.py b/source/tests/pt/model/test_nv_nlist.py index ede8dcc601..aba4908a43 100644 --- a/source/tests/pt/model/test_nv_nlist.py +++ b/source/tests/pt/model/test_nv_nlist.py @@ -4,8 +4,8 @@ These cover the builder paths the DeepEval end-to-end equivalence test (``test_nlist_backend.py``) cannot reach with its small ``batch_naive`` systems: the ``batch_cell_list`` method and the over-capacity distance-trim path, plus the -periodic-box requirement. Built neighbor lists are compared against the native -dense builder at the nlist level (edge topology + geometry). +non-periodic path. Built neighbor lists are compared against the native dense +builder at the nlist level (edge topology + geometry). """ import contextlib @@ -161,7 +161,7 @@ def _assert_nv_matches_native( self, coord: torch.Tensor, atype: torch.Tensor, - box: torch.Tensor, + box: torch.Tensor | None, rcut: float, sel: list[int], force_cell_list: bool = False, @@ -185,7 +185,10 @@ def _assert_nv_matches_native( ) with device_ctx: if force_cell_list: - with patch.object(nv_nlist, "NV_CELL_LIST_THRESHOLD", 1): + with ( + patch.object(nv_nlist, "NV_CELL_LIST_THRESHOLD", 1), + patch.object(nv_nlist, "NV_NONPERIODIC_CELL_LIST_THRESHOLD", 1), + ): nv = builder.build(coord, atype, box, rcut, sel) else: nv = builder.build(coord, atype, box, rcut, sel) @@ -241,10 +244,30 @@ def test_overfull_truncates_to_sel(self) -> None: force_cell_list=False, ) - def test_requires_periodic_box(self) -> None: - """The cell list needs a periodic box; ``box=None`` is rejected.""" + def test_nonperiodic_matches_native(self) -> None: + """Non-periodic systems keep local atoms and match the native builder.""" for device in _TEST_DEVICES: with self.subTest(device=str(device)): coord, atype, _ = self._build_case(1, device) - with self.assertRaises(ValueError): - NvNeighborList().build(coord, atype, None, 3.0, [8]) + self._assert_nv_matches_native( + coord=coord, + atype=atype, + box=None, + rcut=3.0, + sel=[8], + force_cell_list=False, + ) + + def test_nonperiodic_cell_list_matches_native(self) -> None: + """The no-box ``batch_cell_list`` path returns zero shifts.""" + for device in _TEST_DEVICES: + with self.subTest(device=str(device)): + coord, atype, _ = self._build_case(1, device) + self._assert_nv_matches_native( + coord=coord, + atype=atype, + box=None, + rcut=3.0, + sel=[8], + force_cell_list=True, + ) From d268b4e60f304c78a315f841334bfa95a50127a5 Mon Sep 17 00:00:00 2001 From: OutisLi Date: Tue, 9 Jun 2026 14:15:29 +0800 Subject: [PATCH 08/18] fix cuda --- deepmd/pt/utils/nv_nlist.py | 141 ++++++++++++++----------- source/tests/pt/model/test_nv_nlist.py | 66 ++++++------ 2 files changed, 111 insertions(+), 96 deletions(-) diff --git a/deepmd/pt/utils/nv_nlist.py b/deepmd/pt/utils/nv_nlist.py index 74dd70a9c1..188d4f557d 100644 --- a/deepmd/pt/utils/nv_nlist.py +++ b/deepmd/pt/utils/nv_nlist.py @@ -18,8 +18,10 @@ annotations, ) +import contextlib import logging from typing import ( + TYPE_CHECKING, Any, ) @@ -37,6 +39,11 @@ log = logging.getLogger(__name__) +if TYPE_CHECKING: + from collections.abc import ( + Iterator, + ) + def is_nv_available() -> bool: """Whether the ``nvalchemiops`` Toolkit-Ops neighbor list is importable.""" @@ -69,6 +76,17 @@ def choose_nv_nlist_method(nloc: int, *, periodic: bool = True) -> str: return "batch_naive" +@contextlib.contextmanager +def _input_device_context(device: torch.device) -> Iterator[None]: + """Run third-party kernels with both default and current devices pinned.""" + if device.type == "cuda": + with torch.device(device), torch.cuda.device(device): + yield + else: + with torch.device(device): + yield + + class NvNeighborList(NeighborList): """Neighbor-list strategy using the ``nvalchemiops`` kernels. @@ -95,71 +113,72 @@ def build( neighbor_list, ) - nf, nloc = atype.shape[:2] device = coord.device - target_neighbors = int(sum(sel)) - search_capacity = target_neighbors - total_atoms = nf * nloc - coord = coord.reshape(nf, nloc, 3) - periodic = box is not None - if not periodic: - cell = None - pbc = None - else: - cell = box.reshape(nf, 3, 3).to(device=device, dtype=coord.dtype) - coord = normalize_coord(coord, cell) - pbc = torch.ones((nf, 3), dtype=torch.bool, device=device) - positions_for_nlist = coord.reshape(total_atoms, 3).detach() - batch_idx = torch.arange( - nf, dtype=torch.int32, device=device - ).repeat_interleave(nloc) - batch_ptr = torch.arange(nf + 1, dtype=torch.int32, device=device) * nloc - method = choose_nv_nlist_method(nloc, periodic=periodic) - - # Grow the search capacity until all neighbors fit so the distance-sort - # below selects the true nearest ``sum(sel)``. - while True: - nlist_result = neighbor_list( - positions_for_nlist, - float(rcut), + with _input_device_context(device): + nf, nloc = atype.shape[:2] + target_neighbors = int(sum(sel)) + search_capacity = target_neighbors + total_atoms = nf * nloc + coord = coord.reshape(nf, nloc, 3) + periodic = box is not None + if not periodic: + cell = None + pbc = None + else: + cell = box.reshape(nf, 3, 3).to(device=device, dtype=coord.dtype) + coord = normalize_coord(coord, cell) + pbc = torch.ones((nf, 3), dtype=torch.bool, device=device) + positions_for_nlist = coord.reshape(total_atoms, 3).detach() + batch_idx = torch.arange( + nf, dtype=torch.int32, device=device + ).repeat_interleave(nloc) + batch_ptr = torch.arange(nf + 1, dtype=torch.int32, device=device) * nloc + method = choose_nv_nlist_method(nloc, periodic=periodic) + + # Grow the search capacity until all neighbors fit so the distance-sort + # below selects the true nearest ``sum(sel)``. + while True: + nlist_result = neighbor_list( + positions_for_nlist, + float(rcut), + cell=cell, + pbc=pbc, + batch_idx=batch_idx, + batch_ptr=batch_ptr, + method=method, + max_neighbors=int(search_capacity), + return_neighbor_list=False, + wrap_positions=False, + ) + if len(nlist_result) == 2: + neighbor_matrix, num_neighbors = nlist_result + shifts = torch.zeros( + (*neighbor_matrix.shape, 3), + dtype=torch.int32, + device=device, + ) + else: + neighbor_matrix, num_neighbors, shifts = nlist_result + max_found = ( + int(num_neighbors.max().item()) if num_neighbors.numel() > 0 else 0 + ) + if max_found <= search_capacity: + break + search_capacity = max(max_found, _grow_search_capacity(search_capacity)) + + extended_coord, extended_atype, mapping, nlist = _matrix_to_extended_inputs( + coord=coord, + atype=atype, cell=cell, - pbc=pbc, - batch_idx=batch_idx, - batch_ptr=batch_ptr, - method=method, - max_neighbors=int(search_capacity), - return_neighbor_list=False, - wrap_positions=False, + nloc=nloc, + neighbor_matrix=neighbor_matrix, + num_neighbors=num_neighbors, + shifts=shifts, ) - if len(nlist_result) == 2: - neighbor_matrix, num_neighbors = nlist_result - shifts = torch.zeros( - (*neighbor_matrix.shape, 3), - dtype=torch.int32, - device=device, - ) - else: - neighbor_matrix, num_neighbors, shifts = nlist_result - max_found = ( - int(num_neighbors.max().item()) if num_neighbors.numel() > 0 else 0 + nlist = _truncate_to_sel_compiled( + extended_coord, nlist, target_neighbors, float(rcut) ) - if max_found <= search_capacity: - break - search_capacity = max(max_found, _grow_search_capacity(search_capacity)) - - extended_coord, extended_atype, mapping, nlist = _matrix_to_extended_inputs( - coord=coord, - atype=atype, - cell=cell, - nloc=nloc, - neighbor_matrix=neighbor_matrix, - num_neighbors=num_neighbors, - shifts=shifts, - ) - nlist = _truncate_to_sel_compiled( - extended_coord, nlist, target_neighbors, float(rcut) - ) - return extended_coord, extended_atype, nlist, mapping + return extended_coord, extended_atype, nlist, mapping def _grow_search_capacity(capacity: int) -> int: diff --git a/source/tests/pt/model/test_nv_nlist.py b/source/tests/pt/model/test_nv_nlist.py index aba4908a43..5a198cc06d 100644 --- a/source/tests/pt/model/test_nv_nlist.py +++ b/source/tests/pt/model/test_nv_nlist.py @@ -8,7 +8,6 @@ builder at the nlist level (edge topology + geometry). """ -import contextlib import unittest from unittest.mock import ( patch, @@ -24,6 +23,7 @@ ) from deepmd.pt.utils.nv_nlist import ( NvNeighborList, + _input_device_context, ) _NV_AVAILABLE = nv_nlist.is_nv_available() @@ -166,24 +166,18 @@ def _assert_nv_matches_native( sel: list[int], force_cell_list: bool = False, ) -> None: - # native: (extended_coord, extended_atype, mapping, nlist) - native = extend_input_and_build_neighbor_list( - coord, - atype, - rcut, - sel, - mixed_types=True, - box=box, - ) - # NeighborList strategy: (extended_coord, extended_atype, nlist, mapping) - builder = NvNeighborList() - # Pin the current CUDA device so the Toolkit-Ops backend launches there. - device_ctx = ( - torch.cuda.device(coord.device) - if coord.is_cuda - else contextlib.nullcontext() - ) - with device_ctx: + with _input_device_context(coord.device): + # native: (extended_coord, extended_atype, mapping, nlist) + native = extend_input_and_build_neighbor_list( + coord, + atype, + rcut, + sel, + mixed_types=True, + box=box, + ) + # NeighborList strategy: (extended_coord, extended_atype, nlist, mapping) + builder = NvNeighborList() if force_cell_list: with ( patch.object(nv_nlist, "NV_CELL_LIST_THRESHOLD", 1), @@ -192,23 +186,25 @@ def _assert_nv_matches_native( nv = builder.build(coord, atype, box, rcut, sel) else: nv = builder.build(coord, atype, box, rcut, sel) - native_coord, _, native_mapping, native_nlist = native - nv_coord, nv_atype, nv_nlist_out, nv_mapping = nv - # The strategy trims to sum(sel) itself, so the width is fixed. - self.assertEqual(nv_nlist_out.shape[-1], sum(sel)) - self.assertTrue( - torch.equal( - _edge_topology_from_extended(native_mapping, native_nlist), - _edge_topology_from_extended(nv_mapping, nv_nlist_out), + native_coord, _, native_mapping, native_nlist = native + nv_coord, nv_atype, nv_nlist_out, nv_mapping = nv + # The strategy trims to sum(sel) itself, so the width is fixed. + self.assertEqual(nv_nlist_out.shape[-1], sum(sel)) + self.assertTrue( + torch.equal( + _edge_topology_from_extended(native_mapping, native_nlist), + _edge_topology_from_extended(nv_mapping, nv_nlist_out), + ) ) - ) - torch.testing.assert_close( - _edge_geometry_from_extended(native_coord, native_mapping, native_nlist), - _edge_geometry_from_extended(nv_coord, nv_mapping, nv_nlist_out), - atol=1.0e-10, - rtol=1.0e-10, - ) - _assert_extended_atype_matches_mapping(self, atype, nv_atype, nv_mapping) + torch.testing.assert_close( + _edge_geometry_from_extended( + native_coord, native_mapping, native_nlist + ), + _edge_geometry_from_extended(nv_coord, nv_mapping, nv_nlist_out), + atol=1.0e-10, + rtol=1.0e-10, + ) + _assert_extended_atype_matches_mapping(self, atype, nv_atype, nv_mapping) def test_cell_list_matches_native(self) -> None: """The ``batch_cell_list`` method (forced via the threshold) matches the From f57fd33d3625cd65baafeb254ac54a6b4cd0a80f Mon Sep 17 00:00:00 2001 From: OutisLi Date: Tue, 9 Jun 2026 17:27:11 +0800 Subject: [PATCH 09/18] doc --- deepmd/pt/model/model/sezm_model.py | 7 ++++--- deepmd/utils/argcheck.py | 7 ++++++- doc/model/dpa4.md | 3 ++- 3 files changed, 12 insertions(+), 5 deletions(-) diff --git a/deepmd/pt/model/model/sezm_model.py b/deepmd/pt/model/model/sezm_model.py index 977cc7f1aa..0cc283ea6e 100644 --- a/deepmd/pt/model/model/sezm_model.py +++ b/deepmd/pt/model/model/sezm_model.py @@ -2845,9 +2845,10 @@ def deserialize(cls, data: dict[str, Any]) -> SeZMModel: def tf32_precision_ctx(self) -> Generator[None, None, None]: """Context manager to temporarily set TF32 matmul precision. - Training follows ``enable_tf32``. Eval/inference follows - ``DP_TF32_INFER``: 0 keeps ``highest`` precision, 1 selects - ``high``, and 2 selects ``medium``. + Training follows ``enable_tf32`` independently of whether the current + forward uses the compile path. Eval/inference follows ``DP_TF32_INFER``: + 0 keeps ``highest`` precision, 1 selects ``high``, and 2 selects + ``medium``. """ if not torch.cuda.is_available(): yield diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py index 34cdf08ad1..19fbe8cebd 100644 --- a/deepmd/utils/argcheck.py +++ b/deepmd/utils/argcheck.py @@ -3170,7 +3170,12 @@ def sezm_model_args() -> Argument: "Requires torch==2.11. NVIDIA GPUs require CUDA >= 12.6. " "Apple Silicon Macs are also supported. Tested with Python 3.13." ) - doc_enable_tf32 = "If True, enable TF32 matmul precision when use_compile=True." + doc_enable_tf32 = ( + "If True, enable TF32 matmul precision for CUDA training forwards. " + "This training-time setting is independent of `use_compile`; eval-time " + "TF32 is controlled separately by `validating.tf32_infer` or " + "`DP_TF32_INFER`." + ) doc_bridging_method = ( "Short-range bridging method. Currently supports 'ZBL'. " "The value is case-insensitive; set it to 'None' to disable bridging." diff --git a/doc/model/dpa4.md b/doc/model/dpa4.md index 63d0e4538c..96410295fb 100644 --- a/doc/model/dpa4.md +++ b/doc/model/dpa4.md @@ -456,7 +456,8 @@ During training validation, the input option `validating.tf32_infer: true` is translated into `DP_TF32_INFER=1` before model construction, again without overriding an explicitly exported environment variable. Training forwards are controlled separately by -`model.enable_tf32`. +`model.enable_tf32`, independently of whether `model.use_compile` selects the +compiled or eager training path. For molecular dynamics and other workflows that are sensitive to potential energy surface smoothness, keep `DP_TF32_INFER=0`. Enabling TF32 inference may From 0dc112bf79e53d6d6e0f64aeb75d63e38f568d64 Mon Sep 17 00:00:00 2001 From: OutisLi Date: Tue, 9 Jun 2026 20:24:43 +0800 Subject: [PATCH 10/18] fix atype --- deepmd/pt/model/model/sezm_model.py | 2 ++ deepmd/pt/model/model/sezm_spin_model.py | 16 +++++----------- source/tests/pt/model/test_sezm_spin_model.py | 1 + 3 files changed, 8 insertions(+), 11 deletions(-) diff --git a/deepmd/pt/model/model/sezm_model.py b/deepmd/pt/model/model/sezm_model.py index 0cc283ea6e..24271f6461 100644 --- a/deepmd/pt/model/model/sezm_model.py +++ b/deepmd/pt/model/model/sezm_model.py @@ -891,6 +891,7 @@ def forward_common( coord, box=box, fparam=fparam, aparam=aparam ) del coord, box, fparam, aparam + atype = atype.to(device=cc.device, dtype=torch.long) nf, nloc = atype.shape[:2] if cc.ndim == 2: cc = cc.view(nf, nloc, 3) @@ -1562,6 +1563,7 @@ def forward_common_lower( cc_ext, _, fp, ap, input_prec = self._input_type_cast( extended_coord, fparam=fparam, aparam=aparam ) + extended_atype = extended_atype.to(device=cc_ext.device, dtype=torch.long) cc_ext = cc_ext.reshape(extended_atype.shape[0], -1, 3) if extended_coord_corr is not None and extended_coord_corr.ndim == 2: extended_coord_corr = extended_coord_corr.reshape( diff --git a/deepmd/pt/model/model/sezm_spin_model.py b/deepmd/pt/model/model/sezm_spin_model.py index c3ff931f9e..d36ee36fd9 100644 --- a/deepmd/pt/model/model/sezm_spin_model.py +++ b/deepmd/pt/model/model/sezm_spin_model.py @@ -34,9 +34,6 @@ SpinModel, _lookup_type_values, ) -from deepmd.pt.utils.nlist import ( - extend_input_and_build_neighbor_list, -) from deepmd.pt.utils.utils import ( to_torch_tensor, ) @@ -158,6 +155,7 @@ def forward_common( coord, box=box, fparam=fparam, aparam=aparam ) del coord, box, fparam, aparam + atype = atype.to(device=cc.device, dtype=torch.long) nf, nloc = atype.shape[:2] if cc.ndim == 2: cc = cc.view(nf, nloc, 3) @@ -261,6 +259,9 @@ def forward_common_lower( charge_spin: torch.Tensor | None = None, ) -> dict[str, torch.Tensor]: """Return spin-aware lower-interface predictions with internal keys.""" + extended_atype = extended_atype.to( + device=extended_coord.device, dtype=torch.long + ) _, nloc = nlist.shape[:2] ( extended_coord_updated, @@ -539,14 +540,7 @@ def build_neighbor_list( box: torch.Tensor | None, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """Build the real-atom neighbor list before spin expansion.""" - return extend_input_and_build_neighbor_list( - coord, - atype, - self.get_rcut(), - self.real_sel, - mixed_types=True, - box=box, - ) + return super().build_neighbor_list(coord, atype, box) def format_nlist( self, diff --git a/source/tests/pt/model/test_sezm_spin_model.py b/source/tests/pt/model/test_sezm_spin_model.py index 8bf2170c01..aeaec3f8df 100644 --- a/source/tests/pt/model/test_sezm_spin_model.py +++ b/source/tests/pt/model/test_sezm_spin_model.py @@ -224,6 +224,7 @@ def test_factory_shapes_and_masks(self) -> None: def test_forward_lower_matches_forward(self) -> None: """Lower spin interface should match the standard spin forward path.""" model = get_model(self._build_model_params()).to(self.device) + model.eval() out = model(self.coord, self.atype, spin=self.spin, box=self.box) extended_coord, extended_atype, mapping, nlist = ( extend_input_and_build_neighbor_list( From 27bc2c938dd98d7e8897ee83d52769f7907a9809 Mon Sep 17 00:00:00 2001 From: OutisLi Date: Tue, 9 Jun 2026 20:46:46 +0800 Subject: [PATCH 11/18] private triton --- deepmd/pt/model/descriptor/sezm_nn/so2.py | 31 ++++---- .../descriptor/sezm_nn/triton/__init__.py | 14 ++-- .../descriptor/sezm_nn/triton/so2_rotation.py | 70 ++++--------------- .../pt/model/test_descriptor_sezm_triton.py | 30 ++------ 4 files changed, 37 insertions(+), 108 deletions(-) diff --git a/deepmd/pt/model/descriptor/sezm_nn/so2.py b/deepmd/pt/model/descriptor/sezm_nn/so2.py index 84bd43c966..90f83f5207 100644 --- a/deepmd/pt/model/descriptor/sezm_nn/so2.py +++ b/deepmd/pt/model/descriptor/sezm_nn/so2.py @@ -958,7 +958,7 @@ def __init__( self._rotate_to_local_fn = None self._rotate_back_fn = None if self.use_triton_infer: - from .triton import ( + from .triton.so2_rotation import ( rotate_back_block, rotate_back_dense, rotate_to_local_block, @@ -966,11 +966,19 @@ def __init__( ) if self.mmax == 1: - self._rotate_to_local_fn = rotate_to_local_block - self._rotate_back_fn = rotate_back_block + self._rotate_to_local_fn = lambda x, src, wigner: rotate_to_local_block( + x, src, wigner, self.lmax + ) + self._rotate_back_fn = lambda x_local, wigner: rotate_back_block( + x_local, wigner, self.lmax + ) else: - self._rotate_to_local_fn = rotate_to_local_dense - self._rotate_back_fn = rotate_back_dense + self._rotate_to_local_fn = lambda x, src, wigner: rotate_to_local_dense( + x, src, wigner, self.coeff_index_m, self.ebed_dim_full + ) + self._rotate_back_fn = lambda x_local, wigner: rotate_back_dense( + x_local, wigner, self.coeff_index_m, self.ebed_dim_full + ) # === Step 1. Precompute coefficient indices for m-major reduced layout === coeff_index_m = build_m_major_index(self.lmax, self.mmax, device=self.device) @@ -1460,12 +1468,10 @@ def forward( # ``self._rotate_to_local_fn`` was bound in ``__init__`` (the # block kernel for the m-major ``mmax == 1`` layout, dense # otherwise). - x_local = self._rotate_to_local_fn( - x, src, D_full, self.coeff_index_m, self.ebed_dim_full - ) # (E, D_m, C_wide) + x_local = self._rotate_to_local_fn(x, src, D_full) # (E, D_m, C_wide) if self.node_wise_grid_product is not None: x_dst_local = self._rotate_to_local_fn( - x, dst, D_full, self.coeff_index_m, self.ebed_dim_full + x, dst, D_full ) # (E, D_m, C_wide) else: D_m_prime = project_D_to_m( @@ -1623,12 +1629,7 @@ def apply_bias_correction( with nvtx_range("SO2Conv/rotate_back"): Dt_full = edge_cache.Dt_full if self.use_triton_infer and not self.training: - x_message = self._rotate_back_fn( - x_local, - Dt_full, - self.coeff_index_m, - self.ebed_dim_full, - ) # (E, D, C_wide) + x_message = self._rotate_back_fn(x_local, Dt_full) # (E, D, C_wide) else: Dt_from_m = project_Dt_from_m( Dt_full=Dt_full, diff --git a/deepmd/pt/model/descriptor/sezm_nn/triton/__init__.py b/deepmd/pt/model/descriptor/sezm_nn/triton/__init__.py index 001956f9c1..86c95154ce 100644 --- a/deepmd/pt/model/descriptor/sezm_nn/triton/__init__.py +++ b/deepmd/pt/model/descriptor/sezm_nn/triton/__init__.py @@ -2,20 +2,14 @@ """Hardware-accelerated SeZM/DPA4 operators. This package hosts ``make_fx``-composable Triton implementations of SeZM hot -paths. The SO(2) rotation API exposes a general dense path that honors arbitrary -coefficient indices and a block path for the canonical m-major ``mmax=1`` layout. +paths. Kernel entry points are internal implementation details of the SeZM +descriptor; the package-level API only exposes availability. """ from .so2_rotation import ( - rotate_back_block, - rotate_back_dense, - rotate_to_local_block, - rotate_to_local_dense, + TRITON_ROTATION_AVAILABLE, ) __all__ = [ - "rotate_back_block", - "rotate_back_dense", - "rotate_to_local_block", - "rotate_to_local_dense", + "TRITON_ROTATION_AVAILABLE", ] diff --git a/deepmd/pt/model/descriptor/sezm_nn/triton/so2_rotation.py b/deepmd/pt/model/descriptor/sezm_nn/triton/so2_rotation.py index d0e01694ff..1d14037266 100644 --- a/deepmd/pt/model/descriptor/sezm_nn/triton/so2_rotation.py +++ b/deepmd/pt/model/descriptor/sezm_nn/triton/so2_rotation.py @@ -58,8 +58,6 @@ annotations, ) -import math - import torch from torch import ( Tensor, @@ -71,12 +69,6 @@ __all__ = [ "TRITON_ROTATION_AVAILABLE", - "rotate_back_block", - "rotate_back_dense", - "rotate_back_reference", - "rotate_to_local_block", - "rotate_to_local_dense", - "rotate_to_local_reference", ] try: @@ -1159,30 +1151,8 @@ def _launch_rotate_back_bwd( # ====================================================================== -# Block-diagonal launch wrappers + layout detection (mmax == 1) +# Block-diagonal launch wrappers (mmax == 1) # ====================================================================== -def _block_layout_lmax(coeff_index: Tensor, dim_full: int) -> int: - """Return ``lmax`` if ``(coeff_index, dim_full)`` is the m-major ``mmax=1`` - layout that the block-diagonal kernels assume, else ``-1``. - - This intentionally checks only shape-level invariants. The block kernels - ignore ``coeff_index`` values, so production callers must only use the block - entry points when they own the canonical m-major ``mmax=1`` index. - """ - dim_full = int(dim_full) - root = math.isqrt(dim_full) - if root * root != dim_full: - return -1 - lmax = root - 1 - try: - numel = int(coeff_index.shape[0]) - except Exception: # pragma: no cover - exotic shape proxies - return -1 - if lmax < 1 or numel != 3 * lmax + 1: - return -1 - return lmax - - def _launch_bd_to_local_fwd( x: Tensor, src: Tensor, wigner: Tensor, lmax: int ) -> Tensor: @@ -1619,37 +1589,23 @@ def rotate_back_dense( return _rotate_back_op(x_local, wigner, coeff_index, int(dim_full)) -def rotate_to_local_block( - x: Tensor, src: Tensor, wigner: Tensor, coeff_index: Tensor, dim_full: int -) -> Tensor: +def rotate_to_local_block(x: Tensor, src: Tensor, wigner: Tensor, lmax: int) -> Tensor: """Apply the block-diagonal ``global -> local`` rotation. - Use this only when the caller owns the invariant that ``coeff_index`` is the - canonical m-major ``mmax=1`` index produced by - :func:`build_m_major_index`. The kernel ignores the tensor values in - ``coeff_index`` and derives the layout from ``lmax``. + Use this when the caller owns the invariant that the reduced layout is the + canonical m-major ``mmax=1`` layout for ``lmax``. The block kernel derives + the reduced row order from ``lmax`` and does not consume a coefficient-index + tensor. """ - lmax = _block_layout_lmax(coeff_index, dim_full) - if lmax < 0: - raise ValueError( - "rotate_to_local_block requires the m-major mmax=1 coefficient layout." - ) - return _block_to_local_op(x, src, wigner, lmax) + return _block_to_local_op(x, src, wigner, int(lmax)) -def rotate_back_block( - x_local: Tensor, wigner: Tensor, coeff_index: Tensor, dim_full: int -) -> Tensor: +def rotate_back_block(x_local: Tensor, wigner: Tensor, lmax: int) -> Tensor: """Apply the block-diagonal ``local -> global`` rotation. - Use this only when the caller owns the invariant that ``coeff_index`` is the - canonical m-major ``mmax=1`` index produced by - :func:`build_m_major_index`. The kernel ignores the tensor values in - ``coeff_index`` and derives the layout from ``lmax``. + Use this when the caller owns the invariant that ``x_local`` is ordered in + the canonical m-major ``mmax=1`` layout for ``lmax``. The block kernel + derives the reduced column order from ``lmax`` and does not consume a + coefficient-index tensor. """ - lmax = _block_layout_lmax(coeff_index, dim_full) - if lmax < 0: - raise ValueError( - "rotate_back_block requires the m-major mmax=1 coefficient layout." - ) - return _block_back_op(x_local, wigner, lmax) + return _block_back_op(x_local, wigner, int(lmax)) diff --git a/source/tests/pt/model/test_descriptor_sezm_triton.py b/source/tests/pt/model/test_descriptor_sezm_triton.py index 05c88216a4..bcf1b57e3b 100644 --- a/source/tests/pt/model/test_descriptor_sezm_triton.py +++ b/source/tests/pt/model/test_descriptor_sezm_triton.py @@ -94,28 +94,6 @@ def test_noncanonical_same_length_uses_dense_reference(self): rotate_back_reference(x_local, wigner, coeff_index, dim), ) - def test_explicit_block_uses_shape_contract_only(self): - device = torch.device("cpu") - dtype = torch.float32 - lmax = 3 - dim = get_so3_dim_of_lmax(lmax) - canonical = build_m_major_index(lmax, 1, device=device) - coeff_index = torch.roll(canonical, shifts=1) - x = torch.randn(4, dim, 3, device=device, dtype=dtype) - src = torch.tensor([0, 2, 1, 3, 0], dtype=torch.long, device=device) - wigner = torch.randn(src.numel(), dim, dim, device=device, dtype=dtype) - x_local = torch.randn( - src.numel(), coeff_index.numel(), 3, device=device, dtype=dtype - ) - - self.assertEqual( - rotate_to_local_block(x, src, wigner, coeff_index, dim).shape, x_local.shape - ) - self.assertEqual( - rotate_back_block(x_local, wigner, coeff_index, dim).shape, - (src.numel(), dim, 3), - ) - def test_symbolic_trace_noncanonical_same_length_uses_dense_op(self): device = torch.device("cpu") dtype = torch.float32 @@ -186,7 +164,7 @@ def _assert_to_local_matches_reference(self, x0, src, w0, coeff_index, dim): xa = x0.clone().requires_grad_(True) wa = w0.clone().requires_grad_(True) - out = rotate_to_local_block(xa, src, wa, coeff_index, dim) + out = rotate_to_local_block(xa, src, wa, lmax) xr = x0.clone().requires_grad_(True) wr = w0.clone().requires_grad_(True) ref = rotate_to_local_reference(xr, src, wr, coeff_index, dim) @@ -206,7 +184,7 @@ def _assert_back_matches_reference(self, xl0, w0, coeff_index, dim): xa = xl0.clone().requires_grad_(True) wa = w0.clone().requires_grad_(True) - out = rotate_back_block(xa, wa, coeff_index, dim) + out = rotate_back_block(xa, wa, lmax) xr = xl0.clone().requires_grad_(True) wr = w0.clone().requires_grad_(True) ref = rotate_back_reference(xr, wr, coeff_index, dim) @@ -248,7 +226,7 @@ def test_symbolic_make_fx_rotate_to_local_forward_backward_matches_eager(self): def forward_and_grad(x, wigner): x_req = x.detach().requires_grad_(True) w_req = wigner.detach().requires_grad_(True) - out = rotate_to_local_block(x_req, src, w_req, coeff_index, dim) + out = rotate_to_local_block(x_req, src, w_req, lmax) grad_x, grad_wigner = torch.autograd.grad( out, (x_req, w_req), @@ -291,7 +269,7 @@ def test_symbolic_make_fx_rotate_back_forward_backward_matches_eager(self): def forward_and_grad(x_local, wigner): x_req = x_local.detach().requires_grad_(True) w_req = wigner.detach().requires_grad_(True) - out = rotate_back_block(x_req, w_req, coeff_index, dim) + out = rotate_back_block(x_req, w_req, lmax) grad_x, grad_wigner = torch.autograd.grad( out, (x_req, w_req), From eb546c755a3be946d4b8a8c9865bde1a5623209b Mon Sep 17 00:00:00 2001 From: OutisLi Date: Tue, 9 Jun 2026 23:14:06 +0800 Subject: [PATCH 12/18] disable atomic viral when freezing --- deepmd/pt/entrypoints/freeze_pt2.py | 34 +++++++++++++++++++---------- 1 file changed, 23 insertions(+), 11 deletions(-) diff --git a/deepmd/pt/entrypoints/freeze_pt2.py b/deepmd/pt/entrypoints/freeze_pt2.py index c1c36b4fb7..7adce22d1d 100644 --- a/deepmd/pt/entrypoints/freeze_pt2.py +++ b/deepmd/pt/entrypoints/freeze_pt2.py @@ -49,6 +49,9 @@ from deepmd.pt.train.wrapper import ( ModelWrapper, ) +from deepmd.pt.utils.compile_compat import ( + build_inductor_compile_options, +) from deepmd.pt.utils.env import ( DEVICE, ) @@ -219,6 +222,7 @@ def _collect_metadata( model: torch.nn.Module, output_keys: list[str], is_spin: bool | None = None, + do_atomic_virial: bool = False, ) -> dict: """Assemble the flat metadata dict expected by :class:`DeepPotPTExpt`. @@ -261,6 +265,8 @@ def _collect_metadata( "mixed_types": bool(model.mixed_types()), "has_message_passing": _model_has_message_passing(model), "has_comm_artifact": False, + "do_atomic_virial": bool(do_atomic_virial), + "nnei": int(sum(model.get_sel())), "has_default_fparam": bool(model.has_default_fparam()), "default_fparam": _to_py_list(model.get_default_fparam()), "default_chg_spin": _to_py_list(model.get_default_chg_spin()), @@ -468,6 +474,7 @@ def freeze_sezm_to_pt2( *, device: torch.device | None = None, head: str | None = None, + atomic_virial: bool = False, ) -> None: """Freeze a SeZM checkpoint into an AOTInductor ``.pt2`` archive. @@ -484,6 +491,9 @@ def freeze_sezm_to_pt2( Model head to export from a multi-task checkpoint. If omitted, the ``Default`` head is used when present; otherwise multi-task checkpoints must pass an explicit head. Single-task checkpoints must pass ``None``. + atomic_virial + Whether to include per-atom virial outputs in the exported graph. + Disable this for fastest LAMMPS force/energy/total-virial inference. """ from torch._inductor import ( aoti_compile_and_package, @@ -515,9 +525,6 @@ def freeze_sezm_to_pt2( has_spin=is_spin, ) - # do_atomic_virial=True pulls every key that DeepPotPTExpt may read - # (energy, energy_redu, energy_derv_r, energy_derv_c, energy_derv_c_redu) - # into the traced graph. if is_spin: ( ext_coord, @@ -538,7 +545,7 @@ def freeze_sezm_to_pt2( fparam=fparam, aparam=aparam, charge_spin=charge_spin, - do_atomic_virial=True, + do_atomic_virial=atomic_virial, ) else: ( @@ -558,7 +565,7 @@ def freeze_sezm_to_pt2( fparam=fparam, aparam=aparam, charge_spin=charge_spin, - do_atomic_virial=True, + do_atomic_virial=atomic_virial, ) # Output key order is taken from a concrete run; Python dict order @@ -588,14 +595,19 @@ def freeze_sezm_to_pt2( exported = move_to_device_pass(exported, target_device) out_path_str = str(out_path) - # Match the runtime eval compile path's Inductor option: triton.max_tiles=1 - # keeps pointwise grids 1D so the data-dependent compact-edge axis stays on - # Triton's x grid (limit 2**31); the default tiling places it on the y/z - # grid (limit 65535), which overflows for large systems. - with inductor_config.patch({"triton.max_tiles": 1}): + compile_options = build_inductor_compile_options() + # Keep AOTInductor aligned with the eval compile path. ``triton.max_tiles=1`` + # keeps data-dependent edge axes on Triton's x grid, whose bound is large + # enough for production-scale neighbor lists. + with inductor_config.patch({**compile_options, "triton.max_tiles": 1}): aoti_compile_and_package(exported, package_path=out_path_str) - metadata = _collect_metadata(model, output_keys=output_keys, is_spin=is_spin) + metadata = _collect_metadata( + model, + output_keys=output_keys, + is_spin=is_spin, + do_atomic_virial=atomic_virial, + ) with zipfile.ZipFile(out_path_str, "a") as zf: zf.writestr("model/extra/metadata.json", json.dumps(metadata)) # The raw training params are preserved so `dp change-bias` and From 80c1c9f653af6b25e807c559995512eaa787a1e4 Mon Sep 17 00:00:00 2001 From: OutisLi Date: Tue, 9 Jun 2026 23:20:17 +0800 Subject: [PATCH 13/18] doc --- doc/model/dpa4.md | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/doc/model/dpa4.md b/doc/model/dpa4.md index 96410295fb..78d4787a5d 100644 --- a/doc/model/dpa4.md +++ b/doc/model/dpa4.md @@ -480,6 +480,17 @@ to floating-point rounding. They retain full float32 accumulation regardless of workflows. They are compatible with the compile path (`DP_COMPILE_INFER=1`) and reduce both latency and peak memory. +When exporting DPA4/SeZM to `.pt2`, set inference environment variables before +running `dp --pt freeze`. The exported package is an AOTInductor artifact, so +graph-level choices and compiler precision settings are fixed during export and +are not re-evaluated when the `.pt2` file is later loaded by ASE or LAMMPS. +In particular, `DP_TRITON_INFER` selects the SO(2) rotation branch that is +captured into the exported graph, and `DP_TF32_INFER` should be set before +export if TF32 inference is desired. `DP_ACT_INFER` is not a runtime control for +`.pt2` inference: activation checkpointing is a Python/autograd memory-saving +strategy, while `.pt2` inference runs a forward-only AOTI package whose force +and virial computations have already been lowered into the exported graph. + ### Hardware selection DPA4/SeZM is designed for fp32 training and inference. Hardware selection From a6070576a1b48085707e2d9555f5ecd03242c063 Mon Sep 17 00:00:00 2001 From: OutisLi Date: Wed, 10 Jun 2026 12:18:05 +0800 Subject: [PATCH 14/18] feat: add triton radial mixing --- deepmd/pt/model/descriptor/sezm_nn/so2.py | 152 +++- .../descriptor/sezm_nn/triton/__init__.py | 10 +- .../descriptor/sezm_nn/triton/radial_mix.py | 693 ++++++++++++++++++ deepmd/pt/model/descriptor/sezm_nn/utils.py | 19 + .../pt/model/test_descriptor_sezm_triton.py | 148 +++- 5 files changed, 980 insertions(+), 42 deletions(-) create mode 100644 deepmd/pt/model/descriptor/sezm_nn/triton/radial_mix.py diff --git a/deepmd/pt/model/descriptor/sezm_nn/so2.py b/deepmd/pt/model/descriptor/sezm_nn/so2.py index 90f83f5207..b949769c58 100644 --- a/deepmd/pt/model/descriptor/sezm_nn/so2.py +++ b/deepmd/pt/model/descriptor/sezm_nn/so2.py @@ -11,7 +11,6 @@ ) import math -import os from typing import ( TYPE_CHECKING, Any, @@ -78,6 +77,7 @@ np_safe, nvtx_range, safe_numpy_to_tensor, + use_triton_infer, ) if TYPE_CHECKING: @@ -326,6 +326,71 @@ def __init__( # Invalidated on train() via overridden method below. self._cached_weight: torch.Tensor | None = None + # The assembled SO(2) weight is block-diagonal over |m| groups; the + # forward contracts only the diagonal blocks (see _block_diagonal_matmul). + # Each |m| group occupies a contiguous (in, out) block on the diagonal. + self._block_diag_slices = self._build_block_diag_slices() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Parameters + ---------- + x + Input with shape (E, F, D_m_trunc, Cin), where D_m_trunc is the + coefficient dimension of the m-major layout truncated by `mmax`. + + Returns + ------- + torch.Tensor + Output with shape (E, F, D_m_trunc, Cout), where Cout is output channels. + """ + # === Step 1. Flatten coefficient + channel axes for matmul === + # (E, F, D_m, Cin) -> (E, F, D_m*Cin) + n_edge = x.shape[0] + in_dim_total = self.reduced_dim * self.in_channels + x_flat = x.reshape(n_edge, self.n_focus, in_dim_total) + + # === Step 2. Get block-diagonal weight (cached in eval+no_grad) === + if self._cached_weight is not None: + weight = self._cached_weight + else: + weight = self._build_so2_weight() + # Cache only in eval mode with grad disabled (pure inference). + if not self.training and not torch.is_grad_enabled(): + self._cached_weight = weight.detach() + + # === Step 3. Block-diagonal matmul over focus streams + reshape back === + out_flat = self._block_diagonal_matmul(x_flat, weight) + out = out_flat.reshape( + n_edge, self.n_focus, self.reduced_dim, self.out_channels + ) + + # === Step 4. Bias on l=0 scalar index === + if self.mlp_bias: + bias0 = self.bias0.view(self.n_focus, self.out_channels) + out[:, :, 0, :] = out[:, :, 0, :] + bias0.unsqueeze(0) + return out + + def _build_block_diag_slices(self) -> list[tuple[int, int, int, int]]: + """Return the ``(in_start, in_end, out_start, out_end)`` diagonal blocks. + + One entry per ``|m|`` group in m-major order: ``m = 0`` spans + ``lmax + 1`` coefficients and each ``|m| > 0`` spans ``2 * (lmax - m + 1)`` + coefficients (negative and positive orders). + """ + group_sizes = [self.lmax + 1] + [ + 2 * (self.lmax - m + 1) for m in range(1, self.mmax + 1) + ] + slices: list[tuple[int, int, int, int]] = [] + in_off = out_off = 0 + for num in group_sizes: + in_width = num * self.in_channels + out_width = num * self.out_channels + slices.append((in_off, in_off + in_width, out_off, out_off + out_width)) + in_off += in_width + out_off += out_width + return slices + def train(self, mode: bool = True) -> SO2Linear: """Invalidate weight cache when switching to training mode.""" self._cached_weight = None @@ -401,46 +466,38 @@ def _build_so2_weight(self) -> torch.Tensor: weight[pi0:pi1, :, po0:po1] = w_u # pos_in -> pos_out return weight - def forward(self, x: torch.Tensor) -> torch.Tensor: - """ + def _block_diagonal_matmul( + self, x_flat: torch.Tensor, weight: torch.Tensor + ) -> torch.Tensor: + """Contract only the diagonal ``|m|`` blocks of the assembled weight. + + ``weight`` is block-diagonal over ``|m|`` (cross-``|m|`` blocks are + exactly zero), so concatenating the per-group matmuls reproduces the + dense ``einsum`` over the full ``(D_m*Cin, D_m*Cout)`` matrix while + skipping the structural zeros. The result is fp32-equivalent to the + dense path up to the matmul reduction order. + Parameters ---------- - x - Input with shape (E, F, D_m_trunc, Cin), where D_m_trunc is the - coefficient dimension of the m-major layout truncated by `mmax`. + x_flat : torch.Tensor + Flattened input with shape ``(E, F, D_m*Cin)``. + weight : torch.Tensor + Assembled block-diagonal weight with shape ``(D_m*Cin, F, D_m*Cout)``. Returns ------- torch.Tensor - Output with shape (E, F, D_m_trunc, Cout), where Cout is output channels. + Flattened output with shape ``(E, F, D_m*Cout)``. """ - # === Step 1. Flatten coefficient + channel axes for matmul === - # (E, F, D_m, Cin) -> (E, F, D_m*Cin) - n_edge = x.shape[0] - in_dim_total = self.reduced_dim * self.in_channels - x_flat = x.reshape(n_edge, self.n_focus, in_dim_total) - - # === Step 2. Get block-diagonal weight (cached in eval+no_grad) === - if self._cached_weight is not None: - weight = self._cached_weight - else: - weight = self._build_so2_weight() - # Cache only in eval mode with grad disabled (pure inference). - if not self.training and not torch.is_grad_enabled(): - self._cached_weight = weight.detach() - - # === Step 3. Batched matmul over focus streams + reshape back === - # einsum "efi,ifo->efo": (E,F,D_m*Cin) x (D_m*Cin,F,D_m*Cout) -> (E,F,D_m*Cout) - out_flat = torch.einsum("efi,ifo->efo", x_flat, weight) - out = out_flat.reshape( - n_edge, self.n_focus, self.reduced_dim, self.out_channels - ) - - # === Step 4. Bias on l=0 scalar index === - if self.mlp_bias: - bias0 = self.bias0.view(self.n_focus, self.out_channels) - out[:, :, 0, :] = out[:, :, 0, :] + bias0.unsqueeze(0) - return out + blocks = [ + torch.einsum( + "efi,ifo->efo", + x_flat[:, :, in0:in1], + weight[in0:in1, :, out0:out1], + ) + for in0, in1, out0, out1 in self._block_diag_slices + ] + return torch.cat(blocks, dim=-1) def serialize(self) -> dict[str, Any]: trainable = all(p.requires_grad for p in self.parameters()) @@ -577,6 +634,23 @@ def __init__( for p in self.parameters(): p.requires_grad = trainable + # Inference fast path (opt-in via ``DP_TRITON_INFER``): a fused Triton + # kernel replaces the dense scatter and the tiny batched matmul of the + # ``degree_channel`` low-rank branch in the ``mmax == 1`` layout. + self.use_triton_infer = use_triton_infer() + self._radial_mix_block = None + if ( + self.use_triton_infer + and self.mode == "degree_channel" + and self.rank > 0 + and self.mmax == 1 + ): + from .triton.radial_mix import ( + radial_mix_block, + ) + + self._radial_mix_block = radial_mix_block + def _build_dense_scatter_indices(self) -> tuple[torch.Tensor, torch.Tensor]: compact_indices: list[int] = [] dense_indices: list[int] = [] @@ -669,6 +743,10 @@ def forward(self, x_local: torch.Tensor, radial_feat: torch.Tensor) -> torch.Ten compact = kernel_flat.view( x_local.shape[0], self.degree_kernel_size, self.rank ) + if self._radial_mix_block is not None and not self.training: + return self._radial_mix_block( + compact, x_local, self.channel_basis, self.lmax + ) kernel = self._scatter_rank_kernel(compact) mixed = torch.einsum("eoir,eic->eorc", kernel, x_local) channel_basis = self.channel_basis.view(1, 1, self.rank, self.channels) @@ -946,14 +1024,12 @@ def __init__( self.device = env.DEVICE self.precision = RESERVED_PRECISION_DICT[dtype] self.compute_dtype = get_promoted_dtype(self.dtype) - # Optional Triton rotation kernels for the SO(2) convolution, enabled by + # Optional Triton inference kernels for the SO(2) convolution, enabled by # ``DP_TRITON_INFER=1`` (default disabled, in which case the dense # ``bmm`` rotation is used). The flag is read once at construction so it # is a compile-time constant in the traced (``make_fx``) graph, and it # only takes effect during inference. - self.use_triton_infer = os.environ.get( - "DP_TRITON_INFER", "0" - ).strip().lower() in ("1", "true", "yes", "on") + self.use_triton_infer = use_triton_infer() # Triton rotation kernels: block for the mmax == 1 layout, dense otherwise. self._rotate_to_local_fn = None self._rotate_back_fn = None diff --git a/deepmd/pt/model/descriptor/sezm_nn/triton/__init__.py b/deepmd/pt/model/descriptor/sezm_nn/triton/__init__.py index 86c95154ce..3cc27f40d4 100644 --- a/deepmd/pt/model/descriptor/sezm_nn/triton/__init__.py +++ b/deepmd/pt/model/descriptor/sezm_nn/triton/__init__.py @@ -6,10 +6,18 @@ descriptor; the package-level API only exposes availability. """ +from .radial_mix import ( + RADIAL_MIX_TRITON_AVAILABLE, +) from .so2_rotation import ( TRITON_ROTATION_AVAILABLE, ) +# Both kernel modules guard their ``@triton.jit`` definitions behind a ``triton`` +# import, so the two module-level checks are equivalent. Expose a single +# package-level availability flag. +TRITON_AVAILABLE = TRITON_ROTATION_AVAILABLE and RADIAL_MIX_TRITON_AVAILABLE + __all__ = [ - "TRITON_ROTATION_AVAILABLE", + "TRITON_AVAILABLE", ] diff --git a/deepmd/pt/model/descriptor/sezm_nn/triton/radial_mix.py b/deepmd/pt/model/descriptor/sezm_nn/triton/radial_mix.py new file mode 100644 index 0000000000..9f34c41a69 --- /dev/null +++ b/deepmd/pt/model/descriptor/sezm_nn/triton/radial_mix.py @@ -0,0 +1,693 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +# pyright: reportMissingImports=false +# ruff: noqa: ANN001, ANN202 +"""Fused Triton dynamic radial degree mixer for the SeZM/DPA4 descriptor. + +This module provides a clean-room Triton implementation of the +``degree_channel`` branch of :class:`DynamicRadialDegreeMixer` for the +``mmax == 1`` reduced layout. The eager reference applies, per edge ``e`` and +output coefficient ``o``:: + + out[e, o, c] = sum_r channel_basis[r, c] * sum_i K_r[e, o, i] * x[e, i, c] + +where ``K_r`` is the edge-conditioned degree kernel obtained by scattering the +projected radial features ``compact`` into a ``(reduced_dim, reduced_dim)`` +matrix. ``K_r`` is block-diagonal over the ``|m|`` groups, so for +``mmax == 1`` only a ``(lmax+1) x (lmax+1)`` block (orders ``m = 0``) and two +identical ``lmax x lmax`` blocks (orders ``m = -1`` and ``m = +1``) are +non-zero. + +Design goals +------------ +1. **Skip the structural zeros and the dense scratch.** The eager path + materializes the dense kernel ``(E, reduced_dim, reduced_dim, rank)`` via a + scatter and then contracts it with a batched ``einsum``/``bmm`` whose matrices + are tiny (``reduced_dim <= 16``), which is inefficient on cuBLAS and wastes + roughly two thirds of the multiply-adds on off-block zeros. The kernel + instead reads ``compact`` directly and contracts only the structural + non-zeros, with the channel axis vectorized and one program per edge. +2. **Match eager fp32 accuracy.** Accumulation is in fp32, matching the smooth + potential-energy surface contract used throughout the SeZM descriptor. +3. **Compose with the SeZM ``make_fx`` lowering.** The forward and backward are + functional ``torch.library.custom_op`` instances (``mutates_args=()``) with + registered fake kernels and an autograd formula, so + ``make_fx(tracing_mode="symbolic")`` captures the energy path together with + the force autograd graph used by inference. + +Inference-only contract +----------------------- +The operator is opt-in through ``DP_TRITON_INFER`` and is only used in +evaluation, where the force is obtained from ``autograd.grad(energy, coord)``. +The backward therefore returns gradients with respect to ``compact`` and +``x_local`` (both of which carry a path to the coordinates) and ``None`` for +``channel_basis``, which is a parameter and never differentiated by the force +computation. +""" + +from __future__ import ( + annotations, +) + +import torch +from torch import ( + Tensor, +) + +__all__ = [ + "RADIAL_MIX_TRITON_AVAILABLE", + "radial_mix_block", + "radial_mix_reference", +] + +try: + import triton + import triton.language as tl + + RADIAL_MIX_TRITON_AVAILABLE = True +except ImportError: # pragma: no cover - exercised only without triton + RADIAL_MIX_TRITON_AVAILABLE = False + + +# ====================================================================== +# Eager reference / fallback implementation +# ====================================================================== +def _block_layout(lmax: int) -> list[tuple[int, int, int]]: + """Return ``(coeff_start, compact_start, num_l)`` for the ``mmax == 1`` blocks. + + The reduced m-major layout keeps, for each degree ``l``, the orders + ``m = 0`` (the leading ``lmax + 1`` coefficients) followed by ``m = -1`` and + ``m = +1`` (``lmax`` coefficients each). The degree kernel for the two + signed-``m`` blocks is shared, hence the identical ``compact_start``. + """ + num_l0 = lmax + 1 + return [ + (0, 0, num_l0), + (num_l0, num_l0 * num_l0, lmax), + (num_l0 + lmax, num_l0 * num_l0, lmax), + ] + + +def radial_mix_reference( + compact: Tensor, x_local: Tensor, channel_basis: Tensor, lmax: int +) -> Tensor: + """Eager ground truth for :func:`radial_mix_block`. + + Parameters + ---------- + compact : Tensor + Projected radial degree kernel with shape ``(E, degree_kernel_size, R)``. + x_local : Tensor + Edge-local reduced features with shape ``(E, reduced_dim, C)``. + channel_basis : Tensor + Per-rank channel basis with shape ``(R, C)``. + lmax : int + Maximum spherical-harmonic degree. + + Returns + ------- + Tensor + Mixed features with shape ``(E, reduced_dim, C)``. + """ + n_edge, reduced_dim, channels = x_local.shape + out = x_local.new_zeros(n_edge, reduced_dim, channels) + for coeff0, comp0, num_l in _block_layout(int(lmax)): + # K[e, o, i, r] = compact[e, comp0 + i * num_l + o, r] + block = compact[:, comp0 : comp0 + num_l * num_l, :].reshape( + n_edge, num_l, num_l, -1 + ) + block = block.permute(0, 2, 1, 3) # (E, o, i, R) + x_block = x_local[:, coeff0 : coeff0 + num_l, :] # (E, i, C) + inner = torch.einsum("eoir,eic->eocr", block, x_block) # (E, o, C, R) + out[:, coeff0 : coeff0 + num_l, :] = torch.einsum( + "eocr,rc->eoc", inner, channel_basis + ) + return out + + +def _radial_mix_backward_reference( + grad_out: Tensor, compact: Tensor, x_local: Tensor, channel_basis: Tensor, lmax: int +) -> tuple[Tensor, Tensor]: + """Eager backward returning ``(grad_compact, grad_x_local)`` via autograd.""" + with torch.enable_grad(): + compact_req = compact.detach().requires_grad_(True) + x_req = x_local.detach().requires_grad_(True) + out = radial_mix_reference(compact_req, x_req, channel_basis, lmax) + grad_compact, grad_x = torch.autograd.grad(out, [compact_req, x_req], grad_out) + return grad_compact, grad_x + + +# ====================================================================== +# Triton kernels (mmax == 1; LMAX and RANK are constexpr; channels vectorized) +# ====================================================================== +if RADIAL_MIX_TRITON_AVAILABLE: + # The per-edge work is tiny and memory-light, so only the warp count and + # pipeline depth are swept, keyed on the channel width. + _CONFIGS = [ + triton.Config({}, num_warps=1, num_stages=1), + triton.Config({}, num_warps=2, num_stages=1), + triton.Config({}, num_warps=4, num_stages=1), + triton.Config({}, num_warps=2, num_stages=2), + triton.Config({}, num_warps=4, num_stages=2), + ] + _KEY = ["channels"] + + @triton.jit + def _mix_fwd_block( + edge, + chan, + cmask, + x_ptr, + k_ptr, + cb_ptr, + out_ptr, + x_se, + x_sr, + x_sc, + k_se, + k_sk, + k_sr, + cb_sr, + cb_sc, + o_se, + o_sr, + o_sc, + COEFF0: tl.constexpr, + COMPACT0: tl.constexpr, + NUM_L: tl.constexpr, + RANK: tl.constexpr, + ): + """Contract one diagonal block: ``out[o] = sum_r cb[r] sum_i K_r[o,i] x[i]``.""" + for o in tl.static_range(0, NUM_L): + acc = tl.zeros(chan.shape, dtype=tl.float32) + for r in tl.static_range(0, RANK): + partial = tl.zeros(chan.shape, dtype=tl.float32) + for i in tl.static_range(0, NUM_L): + kval = tl.load( + k_ptr + + edge * k_se + + (COMPACT0 + i * NUM_L + o) * k_sk + + r * k_sr + ).to(tl.float32) + x_vec = tl.load( + x_ptr + edge * x_se + (COEFF0 + i) * x_sr + chan * x_sc, + mask=cmask, + other=0.0, + ).to(tl.float32) + partial += kval * x_vec + cb_vec = tl.load( + cb_ptr + r * cb_sr + chan * cb_sc, mask=cmask, other=0.0 + ).to(tl.float32) + acc += partial * cb_vec + tl.store( + out_ptr + edge * o_se + (COEFF0 + o) * o_sr + chan * o_sc, + acc.to(out_ptr.dtype.element_ty), + mask=cmask, + ) + + @triton.autotune(configs=_CONFIGS, key=_KEY) + @triton.jit + def _radial_mix_fwd_kernel( + x_ptr, + k_ptr, + cb_ptr, + out_ptr, + n_edge, + channels, + x_se, + x_sr, + x_sc, + k_se, + k_sk, + k_sr, + cb_sr, + cb_sc, + o_se, + o_sr, + o_sc, + LMAX: tl.constexpr, + RANK: tl.constexpr, + BLOCK_C: tl.constexpr, + ): + edge = tl.program_id(0).to(tl.int64) + chan = tl.arange(0, BLOCK_C) + cmask = chan < channels + num_l0: tl.constexpr = LMAX + 1 + strides = ( + x_se, + x_sr, + x_sc, + k_se, + k_sk, + k_sr, + cb_sr, + cb_sc, + o_se, + o_sr, + o_sc, + ) + # m = 0 block, then the shared m = -1 and m = +1 blocks. + _mix_fwd_block( + edge, + chan, + cmask, + x_ptr, + k_ptr, + cb_ptr, + out_ptr, + *strides, + 0, + 0, + num_l0, + RANK, + ) + _mix_fwd_block( + edge, + chan, + cmask, + x_ptr, + k_ptr, + cb_ptr, + out_ptr, + *strides, + num_l0, + num_l0 * num_l0, + LMAX, + RANK, + ) + _mix_fwd_block( + edge, + chan, + cmask, + x_ptr, + k_ptr, + cb_ptr, + out_ptr, + *strides, + num_l0 + LMAX, + num_l0 * num_l0, + LMAX, + RANK, + ) + + @triton.jit + def _mix_bwd_block( + edge, + chan, + cmask, + go_ptr, + x_ptr, + k_ptr, + cb_ptr, + gx_ptr, + gk_ptr, + go_se, + go_sr, + go_sc, + x_se, + x_sr, + x_sc, + k_se, + k_sk, + k_sr, + cb_sr, + cb_sc, + gx_se, + gx_sr, + gx_sc, + gk_se, + gk_sk, + gk_sr, + COEFF0: tl.constexpr, + COMPACT0: tl.constexpr, + NUM_L: tl.constexpr, + RANK: tl.constexpr, + ): + """Backward of one diagonal block. + + ``grad_x[i] = sum_r cb[r] sum_o K_r[o,i] grad_out[o]`` and + ``grad_K_r[o,i] = sum_c cb[r,c] x[i,c] grad_out[o,c]``. Both accumulators + are scattered with ``atomic_add`` into the zero-initialized outputs: the + ``m = -1`` and ``m = +1`` blocks share the ``compact`` slots, so their + contributions must sum. + """ + for i in tl.static_range(0, NUM_L): + grad_x = tl.zeros(chan.shape, dtype=tl.float32) + for r in tl.static_range(0, RANK): + cb_vec = tl.load( + cb_ptr + r * cb_sr + chan * cb_sc, mask=cmask, other=0.0 + ).to(tl.float32) + partial = tl.zeros(chan.shape, dtype=tl.float32) + for o in tl.static_range(0, NUM_L): + kval = tl.load( + k_ptr + + edge * k_se + + (COMPACT0 + i * NUM_L + o) * k_sk + + r * k_sr + ).to(tl.float32) + go_vec = tl.load( + go_ptr + edge * go_se + (COEFF0 + o) * go_sr + chan * go_sc, + mask=cmask, + other=0.0, + ).to(tl.float32) + partial += kval * go_vec + grad_x += cb_vec * partial + tl.atomic_add( + gx_ptr + edge * gx_se + (COEFF0 + i) * gx_sr + chan * gx_sc, + grad_x, + mask=cmask, + ) + for o in tl.static_range(0, NUM_L): + go_vec = tl.load( + go_ptr + edge * go_se + (COEFF0 + o) * go_sr + chan * go_sc, + mask=cmask, + other=0.0, + ).to(tl.float32) + for i in tl.static_range(0, NUM_L): + x_vec = tl.load( + x_ptr + edge * x_se + (COEFF0 + i) * x_sr + chan * x_sc, + mask=cmask, + other=0.0, + ).to(tl.float32) + for r in tl.static_range(0, RANK): + cb_vec = tl.load( + cb_ptr + r * cb_sr + chan * cb_sc, mask=cmask, other=0.0 + ).to(tl.float32) + grad_k = tl.sum(tl.where(cmask, go_vec * x_vec * cb_vec, 0.0)) + tl.atomic_add( + gk_ptr + + edge * gk_se + + (COMPACT0 + i * NUM_L + o) * gk_sk + + r * gk_sr, + grad_k, + ) + + @triton.autotune(configs=_CONFIGS, key=_KEY, reset_to_zero=["gx_ptr", "gk_ptr"]) + @triton.jit + def _radial_mix_bwd_kernel( + go_ptr, + x_ptr, + k_ptr, + cb_ptr, + gx_ptr, + gk_ptr, + n_edge, + channels, + go_se, + go_sr, + go_sc, + x_se, + x_sr, + x_sc, + k_se, + k_sk, + k_sr, + cb_sr, + cb_sc, + gx_se, + gx_sr, + gx_sc, + gk_se, + gk_sk, + gk_sr, + LMAX: tl.constexpr, + RANK: tl.constexpr, + BLOCK_C: tl.constexpr, + ): + edge = tl.program_id(0).to(tl.int64) + chan = tl.arange(0, BLOCK_C) + cmask = chan < channels + num_l0: tl.constexpr = LMAX + 1 + strides = ( + go_se, + go_sr, + go_sc, + x_se, + x_sr, + x_sc, + k_se, + k_sk, + k_sr, + cb_sr, + cb_sc, + gx_se, + gx_sr, + gx_sc, + gk_se, + gk_sk, + gk_sr, + ) + _mix_bwd_block( + edge, + chan, + cmask, + go_ptr, + x_ptr, + k_ptr, + cb_ptr, + gx_ptr, + gk_ptr, + *strides, + 0, + 0, + num_l0, + RANK, + ) + _mix_bwd_block( + edge, + chan, + cmask, + go_ptr, + x_ptr, + k_ptr, + cb_ptr, + gx_ptr, + gk_ptr, + *strides, + num_l0, + num_l0 * num_l0, + LMAX, + RANK, + ) + _mix_bwd_block( + edge, + chan, + cmask, + go_ptr, + x_ptr, + k_ptr, + cb_ptr, + gx_ptr, + gk_ptr, + *strides, + num_l0 + LMAX, + num_l0 * num_l0, + LMAX, + RANK, + ) + + +# ====================================================================== +# Triton launch wrappers +# ====================================================================== +def _tile_channels(channels: int) -> int: + """Smallest power-of-two channel tile of at least 16 covering ``channels``.""" + tile = 16 + while tile < int(channels): + tile *= 2 + return tile + + +def _launch_forward( + x_local: Tensor, compact: Tensor, channel_basis: Tensor, lmax: int +) -> Tensor: + n_edge, reduced_dim, channels = x_local.shape + rank = int(compact.shape[-1]) + out = torch.empty_like(x_local) + if n_edge == 0: + return out + _radial_mix_fwd_kernel[(n_edge,)]( + x_local, + compact, + channel_basis, + out, + n_edge, + channels, + x_local.stride(0), + x_local.stride(1), + x_local.stride(2), + compact.stride(0), + compact.stride(1), + compact.stride(2), + channel_basis.stride(0), + channel_basis.stride(1), + out.stride(0), + out.stride(1), + out.stride(2), + LMAX=int(lmax), + RANK=rank, + BLOCK_C=_tile_channels(channels), + ) + return out + + +def _launch_backward( + grad_out: Tensor, + x_local: Tensor, + compact: Tensor, + channel_basis: Tensor, + lmax: int, +) -> tuple[Tensor, Tensor]: + n_edge, reduced_dim, channels = x_local.shape + rank = int(compact.shape[-1]) + grad_x = torch.zeros_like(x_local) + grad_compact = torch.zeros_like(compact) + if n_edge == 0: + return grad_compact, grad_x + _radial_mix_bwd_kernel[(n_edge,)]( + grad_out.contiguous(), + x_local, + compact, + channel_basis, + grad_x, + grad_compact, + n_edge, + channels, + grad_out.stride(0), + grad_out.stride(1), + grad_out.stride(2), + x_local.stride(0), + x_local.stride(1), + x_local.stride(2), + compact.stride(0), + compact.stride(1), + compact.stride(2), + channel_basis.stride(0), + channel_basis.stride(1), + grad_x.stride(0), + grad_x.stride(1), + grad_x.stride(2), + grad_compact.stride(0), + grad_compact.stride(1), + grad_compact.stride(2), + LMAX=int(lmax), + RANK=rank, + BLOCK_C=_tile_channels(channels), + ) + return grad_compact, grad_x + + +# ====================================================================== +# Dispatch helpers (triton on CUDA float, eager otherwise) +# ====================================================================== +def _use_triton(tensor: Tensor) -> bool: + return ( + RADIAL_MIX_TRITON_AVAILABLE + and tensor.is_cuda + and tensor.dtype in (torch.float16, torch.bfloat16, torch.float32) + ) + + +def _forward_impl( + compact: Tensor, x_local: Tensor, channel_basis: Tensor, lmax: int +) -> Tensor: + if not _use_triton(x_local): + return radial_mix_reference(compact, x_local, channel_basis, lmax) + return _launch_forward( + x_local.contiguous(), + compact.contiguous(), + channel_basis.contiguous(), + int(lmax), + ) + + +def _backward_impl( + grad_out: Tensor, + compact: Tensor, + x_local: Tensor, + channel_basis: Tensor, + lmax: int, +) -> tuple[Tensor, Tensor]: + if not _use_triton(x_local): + return _radial_mix_backward_reference( + grad_out, compact, x_local, channel_basis, lmax + ) + return _launch_backward( + grad_out, + x_local.contiguous(), + compact.contiguous(), + channel_basis.contiguous(), + int(lmax), + ) + + +# ====================================================================== +# Functional custom ops + fake + autograd registration +# ====================================================================== +_radial_mix_op = torch.library.custom_op( + "sezm_triton::radial_mix_block", mutates_args=() +)(_forward_impl) + +_radial_mix_bwd_op = torch.library.custom_op( + "sezm_triton::radial_mix_block_bwd", mutates_args=() +)(_backward_impl) + + +@_radial_mix_op.register_fake +def _(compact, x_local, channel_basis, lmax): + return torch.empty_like(x_local) + + +@_radial_mix_bwd_op.register_fake +def _(grad_out, compact, x_local, channel_basis, lmax): + return torch.empty_like(compact), torch.empty_like(x_local) + + +def _radial_mix_setup_context(ctx, inputs, output): + compact, x_local, channel_basis, lmax = inputs + ctx.save_for_backward(compact, x_local, channel_basis) + ctx.lmax = lmax + + +def _radial_mix_backward(ctx, grad_out): + compact, x_local, channel_basis = ctx.saved_tensors + grad_compact, grad_x = _radial_mix_bwd_op( + grad_out, compact, x_local, channel_basis, ctx.lmax + ) + # ``channel_basis`` is a parameter; the inference force differentiates only + # w.r.t. coordinates, so its gradient is intentionally not produced. + return grad_compact, grad_x, None, None + + +_radial_mix_op.register_autograd( + _radial_mix_backward, setup_context=_radial_mix_setup_context +) + + +# ====================================================================== +# Public API +# ====================================================================== +def radial_mix_block( + compact: Tensor, x_local: Tensor, channel_basis: Tensor, lmax: int +) -> Tensor: + """Apply the block-diagonal dynamic radial degree mixer (``mmax == 1``). + + Computes the same operation as :func:`radial_mix_reference` while avoiding + the dense scattered kernel and the tiny batched matmul on CUDA. + + Parameters + ---------- + compact : Tensor + Projected radial degree kernel with shape ``(E, degree_kernel_size, R)``. + x_local : Tensor + Edge-local reduced features with shape ``(E, reduced_dim, C)``. + channel_basis : Tensor + Per-rank channel basis with shape ``(R, C)``. + lmax : int + Maximum spherical-harmonic degree. + + Returns + ------- + Tensor + Mixed features with shape ``(E, reduced_dim, C)``. + """ + return _radial_mix_op(compact, x_local, channel_basis, int(lmax)) diff --git a/deepmd/pt/model/descriptor/sezm_nn/utils.py b/deepmd/pt/model/descriptor/sezm_nn/utils.py index 7b3d347bec..0fc7a92b4c 100644 --- a/deepmd/pt/model/descriptor/sezm_nn/utils.py +++ b/deepmd/pt/model/descriptor/sezm_nn/utils.py @@ -11,6 +11,7 @@ ) import math +import os from contextlib import ( contextmanager, ) @@ -34,6 +35,24 @@ ATTN_RES_MODES = ("none", "independent", "dependent") +_TRITON_INFER_TRUE = ("1", "true", "yes", "on") + + +def use_triton_infer() -> bool: + """Return whether the opt-in Triton inference kernels are enabled. + + The flag is controlled by the ``DP_TRITON_INFER`` environment variable and + is read at module construction time so that it becomes a compile-time + constant in the traced (``make_fx``) graph. It only takes effect during + inference; training always uses the dense reference path. + + Returns + ------- + bool + ``True`` when ``DP_TRITON_INFER`` is set to a truthy value. + """ + return os.environ.get("DP_TRITON_INFER", "0").strip().lower() in _TRITON_INFER_TRUE + def init_trunc_normal_fan_in_out( weight: torch.Tensor, diff --git a/source/tests/pt/model/test_descriptor_sezm_triton.py b/source/tests/pt/model/test_descriptor_sezm_triton.py index bcf1b57e3b..4497361223 100644 --- a/source/tests/pt/model/test_descriptor_sezm_triton.py +++ b/source/tests/pt/model/test_descriptor_sezm_triton.py @@ -1,8 +1,10 @@ # SPDX-License-Identifier: LGPL-3.0-or-later -"""Unit tests for the block-diagonal Triton SO(2)/Wigner rotation kernels -(opt-in via ``DP_TRITON_INFER``). +"""Unit tests for the opt-in Triton inference kernels of the SeZM descriptor +(enabled via ``DP_TRITON_INFER``): the block-diagonal SO(2)/Wigner rotation and +the fused dynamic radial degree mixer. -Two properties are checked against the eager PyTorch reference: +For the rotation kernels two properties are checked against the eager PyTorch +reference: 1. Numerical correctness of ``rotate_to_local`` / ``rotate_back`` (forward and backward) across ``lmax`` 2-5 with ``mmax == 1`` -- the only layout the block @@ -29,6 +31,14 @@ build_m_major_index, get_so3_dim_of_lmax, ) +from deepmd.pt.model.descriptor.sezm_nn.so2 import ( + DynamicRadialDegreeMixer, +) +from deepmd.pt.model.descriptor.sezm_nn.triton.radial_mix import ( + RADIAL_MIX_TRITON_AVAILABLE, + radial_mix_block, + radial_mix_reference, +) from deepmd.pt.model.descriptor.sezm_nn.triton.so2_rotation import ( TRITON_ROTATION_AVAILABLE, rotate_back_block, @@ -401,5 +411,137 @@ def forward_and_grad(x_local, wigner): self.assertGreater(float(grad_w_eager.abs().max()), 0.0) +@unittest.skipIf(not _CUDA, "CUDA is required for the Triton radial-mix kernel") +@unittest.skipIf(not RADIAL_MIX_TRITON_AVAILABLE, "Triton is not available") +class TestSeZMTritonRadialMix(unittest.TestCase): + """Fused dynamic radial degree mixer (``degree_channel``, ``mmax == 1``). + + The Triton kernel and its eager reference are checked against the production + scatter path of :class:`DynamicRadialDegreeMixer`, and the forward/backward + are checked for symbolic ``make_fx`` composability with the inference-force + autograd graph. + """ + + def setUp(self): + self.device = torch.device("cuda") + self.dtype = torch.float32 + self.n_edge, self.channels, self.rank = 4096, 64, 1 + self.tol = {"rtol": 2e-4, "atol": 2e-4} + + def _mixer(self, lmax): + return ( + DynamicRadialDegreeMixer( + lmax=lmax, + mmax=1, + channels=self.channels, + mode="degree_channel", + rank=self.rank, + dtype=self.dtype, + seed=1, + trainable=True, + ) + .to(self.device) + .eval() + ) + + def _inputs(self, mixer, seed): + gen = torch.Generator(device=self.device).manual_seed(seed) + x_local = torch.randn( + self.n_edge, + mixer.reduced_dim, + self.channels, + device=self.device, + dtype=self.dtype, + generator=gen, + ) + radial_feat = torch.randn( + self.n_edge, + mixer.reduced_dim, + self.channels, + device=self.device, + dtype=self.dtype, + generator=gen, + ) + compact = mixer._project_radial(radial_feat).view( + self.n_edge, mixer.degree_kernel_size, self.rank + ) + return x_local, radial_feat, compact + + def test_reference_matches_module_eager_path(self): + """The block-split reference reproduces the module's dense scatter path.""" + for lmax in (2, 3, 4, 5): + with self.subTest(lmax=lmax): + mixer = self._mixer(lmax) + # Force the dense scatter path regardless of the ambient flag. + mixer._radial_mix_block = None + x_local, radial_feat, compact = self._inputs(mixer, seed=lmax) + with torch.no_grad(): + module_out = mixer(x_local, radial_feat) + ref_out = radial_mix_reference( + compact, x_local, mixer.channel_basis, lmax + ) + torch.testing.assert_close(ref_out, module_out, **self.tol) + + def test_triton_forward_matches_reference(self): + for lmax in (2, 3, 4, 5): + with self.subTest(lmax=lmax): + mixer = self._mixer(lmax) + x_local, _, compact = self._inputs(mixer, seed=lmax) + with torch.no_grad(): + out = radial_mix_block(compact, x_local, mixer.channel_basis, lmax) + ref = radial_mix_reference( + compact, x_local, mixer.channel_basis, lmax + ) + torch.testing.assert_close(out, ref, **self.tol) + + def test_triton_backward_matches_reference(self): + """Backward correctness on a fresh first call (checks reset_to_zero).""" + for lmax in (2, 3, 4, 5): + with self.subTest(lmax=lmax): + mixer = self._mixer(lmax) + x_local, _, compact = self._inputs(mixer, seed=lmax) + grad_out = torch.randn_like(x_local) + + ca = compact.detach().requires_grad_(True) + xa = x_local.detach().requires_grad_(True) + out = radial_mix_block(ca, xa, mixer.channel_basis, lmax) + grad_ca, grad_xa = torch.autograd.grad(out, [ca, xa], grad_out) + + cr = compact.detach().requires_grad_(True) + xr = x_local.detach().requires_grad_(True) + ref = radial_mix_reference(cr, xr, mixer.channel_basis, lmax) + grad_cr, grad_xr = torch.autograd.grad(ref, [cr, xr], grad_out) + + torch.testing.assert_close(grad_ca, grad_cr, **self.tol) + torch.testing.assert_close(grad_xa, grad_xr, **self.tol) + + def test_symbolic_make_fx_forward_backward_matches_eager(self): + """Symbolic FX captures the mixer forward and its inference-force graph.""" + lmax = 3 + mixer = self._mixer(lmax) + x_local, _, compact = self._inputs(mixer, seed=7) + channel_basis = mixer.channel_basis + grad_seed = torch.randn_like(x_local) + + def forward_and_grad(compact, x_local): + compact_req = compact.detach().requires_grad_(True) + x_req = x_local.detach().requires_grad_(True) + out = radial_mix_block(compact_req, x_req, channel_basis, lmax) + grad_compact, grad_x = torch.autograd.grad( + out, (compact_req, x_req), grad_seed + ) + return out, grad_compact, grad_x + + eager = forward_and_grad(compact, x_local) + traced = make_fx( + forward_and_grad, + tracing_mode="symbolic", + _allow_non_fake_inputs=True, + )(compact, x_local) + for got, want in zip(traced(compact, x_local), eager, strict=True): + torch.testing.assert_close(got, want, **self.tol) + self.assertIn("sezm_triton.radial_mix_block", traced.code) + + if __name__ == "__main__": unittest.main() From 345e3780a56c4c05247e7b5936f904152d4303c9 Mon Sep 17 00:00:00 2001 From: OutisLi Date: Wed, 10 Jun 2026 14:44:52 +0800 Subject: [PATCH 15/18] feat: layout opt --- deepmd/pt/model/descriptor/sezm_nn/so2.py | 42 ++- .../descriptor/sezm_nn/triton/radial_mix.py | 160 +++++--- .../descriptor/sezm_nn/triton/so2_rotation.py | 347 ++++++++++++++++++ .../pt/model/test_descriptor_sezm_triton.py | 50 +++ 4 files changed, 541 insertions(+), 58 deletions(-) diff --git a/deepmd/pt/model/descriptor/sezm_nn/so2.py b/deepmd/pt/model/descriptor/sezm_nn/so2.py index b949769c58..9ff0c9fac4 100644 --- a/deepmd/pt/model/descriptor/sezm_nn/so2.py +++ b/deepmd/pt/model/descriptor/sezm_nn/so2.py @@ -1035,7 +1035,7 @@ def __init__( self._rotate_back_fn = None if self.use_triton_infer: from .triton.so2_rotation import ( - rotate_back_block, + rotate_back_block_so2, rotate_back_dense, rotate_to_local_block, rotate_to_local_dense, @@ -1045,7 +1045,10 @@ def __init__( self._rotate_to_local_fn = lambda x, src, wigner: rotate_to_local_block( x, src, wigner, self.lmax ) - self._rotate_back_fn = lambda x_local, wigner: rotate_back_block( + # The block kernel reads the (E, F, D_m, Cf) focus layout directly, + # so the rotate-back path passes ``x_local`` before the global + # reshape and the transpose-back copy is skipped (see Step 7). + self._rotate_back_fn = lambda x_local, wigner: rotate_back_block_so2( x_local, wigner, self.lmax ) else: @@ -1695,27 +1698,32 @@ def apply_bias_correction( ) x_local = x_local * alpha.unsqueeze(-1).unsqueeze(-1) - # Restore reduced global layout for inverse rotation - x_local = x_local.transpose(1, 2).contiguous() # (E, D_m, F, Cf) - x_local = x_local.reshape( - n_edge, self.reduced_dim, self.hidden_channels - ) # (E, D_m, C_wide) - # === Step 7. Rotate back to global frame === with nvtx_range("SO2Conv/rotate_back"): Dt_full = edge_cache.Dt_full - if self.use_triton_infer and not self.training: + if self.use_triton_infer and self.mmax == 1 and not self.training: + # The block kernel consumes the (E, F, D_m, Cf) focus layout in + # place, folding the inverse transpose into its channel addressing. x_message = self._rotate_back_fn(x_local, Dt_full) # (E, D, C_wide) else: - Dt_from_m = project_Dt_from_m( - Dt_full=Dt_full, - coeff_index_m=self.coeff_index_m, - ebed_dim_full=self.ebed_dim_full, - cache=edge_cache.Dt_from_m_cache, - key_lmax=self.lmax, - key_mmax=self.mmax, + # Restore reduced global layout (E, D_m, C_wide) for inverse rotation. + x_local = ( + x_local.transpose(1, 2) + .contiguous() + .reshape(n_edge, self.reduced_dim, self.hidden_channels) ) - x_message = torch.bmm(Dt_from_m, x_local) # (E, D, C_wide) + if self.use_triton_infer and not self.training: + x_message = self._rotate_back_fn(x_local, Dt_full) # (E, D, C_wide) + else: + Dt_from_m = project_Dt_from_m( + Dt_full=Dt_full, + coeff_index_m=self.coeff_index_m, + ebed_dim_full=self.ebed_dim_full, + cache=edge_cache.Dt_from_m_cache, + key_lmax=self.lmax, + key_mmax=self.mmax, + ) + x_message = torch.bmm(Dt_from_m, x_local) # (E, D, C_wide) # Reduced layouts keep only 2*mmax+1 orders for l>mmax. Applying the # inverse-rotation degree rescale after the global lift restores the # full-basis amplitude expected by the block output contract. diff --git a/deepmd/pt/model/descriptor/sezm_nn/triton/radial_mix.py b/deepmd/pt/model/descriptor/sezm_nn/triton/radial_mix.py index 9f34c41a69..ba9cf5bd52 100644 --- a/deepmd/pt/model/descriptor/sezm_nn/triton/radial_mix.py +++ b/deepmd/pt/model/descriptor/sezm_nn/triton/radial_mix.py @@ -290,22 +290,17 @@ def _radial_mix_fwd_kernel( ) @triton.jit - def _mix_bwd_block( + def _mix_bwd_grad_x_block( edge, chan, cmask, go_ptr, - x_ptr, k_ptr, cb_ptr, gx_ptr, - gk_ptr, go_se, go_sr, go_sc, - x_se, - x_sr, - x_sc, k_se, k_sk, k_sr, @@ -314,21 +309,16 @@ def _mix_bwd_block( gx_se, gx_sr, gx_sc, - gk_se, - gk_sk, - gk_sr, COEFF0: tl.constexpr, COMPACT0: tl.constexpr, NUM_L: tl.constexpr, RANK: tl.constexpr, ): - """Backward of one diagonal block. + """Input gradient of one diagonal block. - ``grad_x[i] = sum_r cb[r] sum_o K_r[o,i] grad_out[o]`` and - ``grad_K_r[o,i] = sum_c cb[r,c] x[i,c] grad_out[o,c]``. Both accumulators - are scattered with ``atomic_add`` into the zero-initialized outputs: the - ``m = -1`` and ``m = +1`` blocks share the ``compact`` slots, so their - contributions must sum. + Computes ``grad_x[i] = sum_r cb[r] sum_o K_r[o,i] grad_out[o]``. Each edge + owns its rows and the three blocks address disjoint coefficient rows, so + the result is written once with a plain store rather than an atomic add. """ for i in tl.static_range(0, NUM_L): grad_x = tl.zeros(chan.shape, dtype=tl.float32) @@ -351,37 +341,86 @@ def _mix_bwd_block( ).to(tl.float32) partial += kval * go_vec grad_x += cb_vec * partial - tl.atomic_add( + tl.store( gx_ptr + edge * gx_se + (COEFF0 + i) * gx_sr + chan * gx_sc, - grad_x, + grad_x.to(gx_ptr.dtype.element_ty), mask=cmask, ) + + @triton.jit + def _mix_bwd_grad_k_block( + edge, + chan, + cmask, + go_ptr, + x_ptr, + cb_ptr, + gk_ptr, + go_se, + go_sr, + go_sc, + x_se, + x_sr, + x_sc, + cb_sr, + cb_sc, + gk_se, + gk_sk, + gk_sr, + COEFF0: tl.constexpr, + COEFF1: tl.constexpr, + COMPACT0: tl.constexpr, + NUM_L: tl.constexpr, + RANK: tl.constexpr, + SHARED: tl.constexpr, + ): + """Kernel gradient of one diagonal block. + + Computes ``grad_K_r[o,i] = sum_c cb[r,c] x[i,c] grad_out[o,c]``. The + ``m = -1`` and ``m = +1`` blocks (``SHARED``) write the same ``compact`` + slots; their contributions are summed in registers and stored once, which + removes the atomic add and the zero-initialization the original required. + """ for o in tl.static_range(0, NUM_L): go_vec = tl.load( go_ptr + edge * go_se + (COEFF0 + o) * go_sr + chan * go_sc, mask=cmask, other=0.0, ).to(tl.float32) + if SHARED: + go_vec_sh = tl.load( + go_ptr + edge * go_se + (COEFF1 + o) * go_sr + chan * go_sc, + mask=cmask, + other=0.0, + ).to(tl.float32) for i in tl.static_range(0, NUM_L): x_vec = tl.load( x_ptr + edge * x_se + (COEFF0 + i) * x_sr + chan * x_sc, mask=cmask, other=0.0, ).to(tl.float32) + prod = go_vec * x_vec + if SHARED: + x_vec_sh = tl.load( + x_ptr + edge * x_se + (COEFF1 + i) * x_sr + chan * x_sc, + mask=cmask, + other=0.0, + ).to(tl.float32) + prod += go_vec_sh * x_vec_sh for r in tl.static_range(0, RANK): cb_vec = tl.load( cb_ptr + r * cb_sr + chan * cb_sc, mask=cmask, other=0.0 ).to(tl.float32) - grad_k = tl.sum(tl.where(cmask, go_vec * x_vec * cb_vec, 0.0)) - tl.atomic_add( + grad_k = tl.sum(tl.where(cmask, prod * cb_vec, 0.0)) + tl.store( gk_ptr + edge * gk_se + (COMPACT0 + i * NUM_L + o) * gk_sk + r * gk_sr, - grad_k, + grad_k.to(gk_ptr.dtype.element_ty), ) - @triton.autotune(configs=_CONFIGS, key=_KEY, reset_to_zero=["gx_ptr", "gk_ptr"]) + @triton.autotune(configs=_CONFIGS, key=_KEY) @triton.jit def _radial_mix_bwd_kernel( go_ptr, @@ -417,13 +456,12 @@ def _radial_mix_bwd_kernel( chan = tl.arange(0, BLOCK_C) cmask = chan < channels num_l0: tl.constexpr = LMAX + 1 - strides = ( + + # === Step 1. Input gradient: three disjoint coefficient blocks === + grad_x_strides = ( go_se, go_sr, go_sc, - x_se, - x_sr, - x_sc, k_se, k_sk, k_sr, @@ -432,57 +470,95 @@ def _radial_mix_bwd_kernel( gx_se, gx_sr, gx_sc, - gk_se, - gk_sk, - gk_sr, ) - _mix_bwd_block( + _mix_bwd_grad_x_block( edge, chan, cmask, go_ptr, - x_ptr, k_ptr, cb_ptr, gx_ptr, - gk_ptr, - *strides, + *grad_x_strides, 0, 0, num_l0, RANK, ) - _mix_bwd_block( + _mix_bwd_grad_x_block( edge, chan, cmask, go_ptr, - x_ptr, k_ptr, cb_ptr, gx_ptr, - gk_ptr, - *strides, + *grad_x_strides, num_l0, num_l0 * num_l0, LMAX, RANK, ) - _mix_bwd_block( + _mix_bwd_grad_x_block( edge, chan, cmask, go_ptr, - x_ptr, k_ptr, cb_ptr, gx_ptr, + *grad_x_strides, + num_l0 + LMAX, + num_l0 * num_l0, + LMAX, + RANK, + ) + + # === Step 2. Kernel gradient: m=0 block, then summed m=+-1 blocks === + grad_k_strides = ( + go_se, + go_sr, + go_sc, + x_se, + x_sr, + x_sc, + cb_sr, + cb_sc, + gk_se, + gk_sk, + gk_sr, + ) + _mix_bwd_grad_k_block( + edge, + chan, + cmask, + go_ptr, + x_ptr, + cb_ptr, gk_ptr, - *strides, + *grad_k_strides, + 0, + 0, + 0, + num_l0, + RANK, + False, + ) + _mix_bwd_grad_k_block( + edge, + chan, + cmask, + go_ptr, + x_ptr, + cb_ptr, + gk_ptr, + *grad_k_strides, + num_l0, num_l0 + LMAX, num_l0 * num_l0, LMAX, RANK, + True, ) @@ -539,8 +615,10 @@ def _launch_backward( ) -> tuple[Tensor, Tensor]: n_edge, reduced_dim, channels = x_local.shape rank = int(compact.shape[-1]) - grad_x = torch.zeros_like(x_local) - grad_compact = torch.zeros_like(compact) + # Every output element is written exactly once (input rows are disjoint and + # the shared m=+-1 kernel slots are summed in-register), so no zero-init. + grad_x = torch.empty_like(x_local) + grad_compact = torch.empty_like(compact) if n_edge == 0: return grad_compact, grad_x _radial_mix_bwd_kernel[(n_edge,)]( diff --git a/deepmd/pt/model/descriptor/sezm_nn/triton/so2_rotation.py b/deepmd/pt/model/descriptor/sezm_nn/triton/so2_rotation.py index 1d14037266..decfc0ac2d 100644 --- a/deepmd/pt/model/descriptor/sezm_nn/triton/so2_rotation.py +++ b/deepmd/pt/model/descriptor/sezm_nn/triton/so2_rotation.py @@ -929,6 +929,178 @@ def _bd_back_bwd_kernel( mask=cmask, ) + @triton.autotune(configs=_BD_CONFIGS, key=["channels"]) + @triton.jit + def _bd_back_so2_fwd_kernel( + xl_ptr, + w_ptr, + out_ptr, + n_edge, + channels, + xl_se, + xl_sf, + xl_sr, + xl_sc, + w_se, + w_sr, + w_sk, + o_se, + o_sd, + o_sc, + LMAX: tl.constexpr, + FOCUS_DIM: tl.constexpr, + BLOCK_C: tl.constexpr, + ): + """Block-diagonal rotate_back reading the per-focus layout in place. + + ``out[e, l^2+j, c] = sum_m W[e, l^2+j, l^2+l+m] * x_local[e, f, (l,m), cf]`` + with ``c = f * FOCUS_DIM + cf``. Decoding the channel as ``(f, cf)`` folds + the ``(F, D_m, Cf) -> (D_m, C_wide)`` transpose into the addressing, so the + caller passes the SO(2) focus tensor without an explicit copy. + """ + edge = tl.program_id(0).to(tl.int64) + chan = tl.arange(0, BLOCK_C) + cmask = chan < channels + xl_co = (chan // FOCUS_DIM) * xl_sf + (chan % FOCUS_DIM) * xl_sc + for l in tl.static_range(0, LMAX + 1): + base = l * l + r0 = base + l + xl0 = tl.load( + xl_ptr + edge * xl_se + l * xl_sr + xl_co, mask=cmask, other=0.0 + ).to(tl.float32) + if l >= 1: + xl_m = tl.load( + xl_ptr + edge * xl_se + (LMAX + l) * xl_sr + xl_co, + mask=cmask, + other=0.0, + ).to(tl.float32) + xl_p = tl.load( + xl_ptr + edge * xl_se + (2 * LMAX + l) * xl_sr + xl_co, + mask=cmask, + other=0.0, + ).to(tl.float32) + for j in tl.static_range(0, 2 * l + 1): + d = base + j + acc = tl.load(w_ptr + edge * w_se + d * w_sr + r0 * w_sk) * xl0 + if l >= 1: + acc += ( + tl.load(w_ptr + edge * w_se + d * w_sr + (r0 - 1) * w_sk) * xl_m + ) + acc += ( + tl.load(w_ptr + edge * w_se + d * w_sr + (r0 + 1) * w_sk) * xl_p + ) + tl.store( + out_ptr + edge * o_se + d * o_sd + chan * o_sc, + acc.to(out_ptr.dtype.element_ty), + mask=cmask, + ) + + @triton.autotune(configs=_BD_CONFIGS, key=["channels"]) + @triton.jit + def _bd_back_so2_bwd_kernel( + go_ptr, + xl_ptr, + w_ptr, + gxl_ptr, + gw_ptr, + n_edge, + channels, + go_se, + go_sd, + go_sc, + xl_se, + xl_sf, + xl_sr, + xl_sc, + w_se, + w_sr, + w_sk, + gxl_se, + gxl_sf, + gxl_sr, + gxl_sc, + gw_se, + gw_sr, + gw_sk, + LMAX: tl.constexpr, + FOCUS_DIM: tl.constexpr, + BLOCK_C: tl.constexpr, + ): + """Backward of :func:`_bd_back_so2_fwd_kernel`. + + Writes ``grad_x_local`` in the per-focus layout (decoding the channel as + ``(f, cf)`` exactly as the forward) and accumulates ``grad_W`` over the + full channel width, i.e. summed across focus streams. + """ + edge = tl.program_id(0).to(tl.int64) + chan = tl.arange(0, BLOCK_C) + cmask = chan < channels + xl_co = (chan // FOCUS_DIM) * xl_sf + (chan % FOCUS_DIM) * xl_sc + gxl_co = (chan // FOCUS_DIM) * gxl_sf + (chan % FOCUS_DIM) * gxl_sc + for l in tl.static_range(0, LMAX + 1): + base = l * l + r0 = base + l + xl0 = tl.load( + xl_ptr + edge * xl_se + l * xl_sr + xl_co, mask=cmask, other=0.0 + ).to(tl.float32) + gxl0 = tl.zeros((BLOCK_C,), dtype=tl.float32) + if l >= 1: + xl_m = tl.load( + xl_ptr + edge * xl_se + (LMAX + l) * xl_sr + xl_co, + mask=cmask, + other=0.0, + ).to(tl.float32) + xl_p = tl.load( + xl_ptr + edge * xl_se + (2 * LMAX + l) * xl_sr + xl_co, + mask=cmask, + other=0.0, + ).to(tl.float32) + gxl_m = tl.zeros((BLOCK_C,), dtype=tl.float32) + gxl_p = tl.zeros((BLOCK_C,), dtype=tl.float32) + for j in tl.static_range(0, 2 * l + 1): + d = base + j + go_d = tl.load( + go_ptr + edge * go_se + d * go_sd + chan * go_sc, + mask=cmask, + other=0.0, + ).to(tl.float32) + gxl0 += tl.load(w_ptr + edge * w_se + d * w_sr + r0 * w_sk) * go_d + tl.store( + gw_ptr + edge * gw_se + d * gw_sr + r0 * gw_sk, + tl.sum(go_d * xl0).to(gw_ptr.dtype.element_ty), + ) + if l >= 1: + gxl_m += ( + tl.load(w_ptr + edge * w_se + d * w_sr + (r0 - 1) * w_sk) * go_d + ) + gxl_p += ( + tl.load(w_ptr + edge * w_se + d * w_sr + (r0 + 1) * w_sk) * go_d + ) + tl.store( + gw_ptr + edge * gw_se + d * gw_sr + (r0 - 1) * gw_sk, + tl.sum(go_d * xl_m).to(gw_ptr.dtype.element_ty), + ) + tl.store( + gw_ptr + edge * gw_se + d * gw_sr + (r0 + 1) * gw_sk, + tl.sum(go_d * xl_p).to(gw_ptr.dtype.element_ty), + ) + tl.store( + gxl_ptr + edge * gxl_se + l * gxl_sr + gxl_co, + gxl0.to(gxl_ptr.dtype.element_ty), + mask=cmask, + ) + if l >= 1: + tl.store( + gxl_ptr + edge * gxl_se + (LMAX + l) * gxl_sr + gxl_co, + gxl_m.to(gxl_ptr.dtype.element_ty), + mask=cmask, + ) + tl.store( + gxl_ptr + edge * gxl_se + (2 * LMAX + l) * gxl_sr + gxl_co, + gxl_p.to(gxl_ptr.dtype.element_ty), + mask=cmask, + ) + # ====================================================================== # Triton launch wrappers @@ -1290,6 +1462,79 @@ def _launch_bd_back_bwd( return grad_x_local, grad_wigner +def _launch_bd_back_so2_fwd(x_local_4d: Tensor, wigner: Tensor, lmax: int) -> Tensor: + n_edge, n_focus, _reduced, focus_dim = (int(s) for s in x_local_4d.shape) + channels = n_focus * focus_dim + dim_full = (lmax + 1) ** 2 + out = torch.empty( + (n_edge, dim_full, channels), dtype=x_local_4d.dtype, device=x_local_4d.device + ) + if n_edge == 0: + return out + _bd_back_so2_fwd_kernel[(n_edge,)]( + x_local_4d, + wigner, + out, + n_edge, + channels, + x_local_4d.stride(0), + x_local_4d.stride(1), + x_local_4d.stride(2), + x_local_4d.stride(3), + wigner.stride(0), + wigner.stride(1), + wigner.stride(2), + out.stride(0), + out.stride(1), + out.stride(2), + LMAX=lmax, + FOCUS_DIM=focus_dim, + BLOCK_C=_tile_dim(channels), + ) + return out + + +def _launch_bd_back_so2_bwd( + grad_out: Tensor, x_local_4d: Tensor, wigner: Tensor, lmax: int +) -> tuple[Tensor, Tensor]: + n_edge, n_focus, _reduced, focus_dim = (int(s) for s in x_local_4d.shape) + channels = n_focus * focus_dim + grad_x_local = torch.empty_like(x_local_4d) + grad_wigner = torch.zeros_like(wigner) + if n_edge == 0: + return grad_x_local, grad_wigner + _bd_back_so2_bwd_kernel[(n_edge,)]( + grad_out, + x_local_4d, + wigner, + grad_x_local, + grad_wigner, + n_edge, + channels, + grad_out.stride(0), + grad_out.stride(1), + grad_out.stride(2), + x_local_4d.stride(0), + x_local_4d.stride(1), + x_local_4d.stride(2), + x_local_4d.stride(3), + wigner.stride(0), + wigner.stride(1), + wigner.stride(2), + grad_x_local.stride(0), + grad_x_local.stride(1), + grad_x_local.stride(2), + grad_x_local.stride(3), + grad_wigner.stride(0), + grad_wigner.stride(1), + grad_wigner.stride(2), + LMAX=lmax, + FOCUS_DIM=focus_dim, + BLOCK_C=_tile_dim(channels), + ) + return grad_x_local, grad_wigner + + # ====================================================================== # Dispatch helpers (triton on CUDA float, eager otherwise) # ====================================================================== @@ -1609,3 +1854,105 @@ def rotate_back_block(x_local: Tensor, wigner: Tensor, lmax: int) -> Tensor: coefficient-index tensor. """ return _block_back_op(x_local, wigner, int(lmax)) + + +# ====================================================================== +# Layout-aware block rotate_back (per-focus SO(2) layout, mmax == 1) +# ====================================================================== +# Consumes the (E, F, D_m, Cf) focus layout produced by the SO(2) layers so the +# caller can skip the ``transpose(1, 2).contiguous()`` that would otherwise +# materialize (E, D_m, F * Cf) before the inverse rotation. + + +def _block_rotate_back_so2_impl( + x_local_4d: Tensor, wigner: Tensor, lmax: int +) -> Tensor: + if not _use_triton(x_local_4d): + n_edge, n_focus, reduced_dim, focus_dim = x_local_4d.shape + x_std = x_local_4d.transpose(1, 2).reshape( + n_edge, reduced_dim, n_focus * focus_dim + ) + coeff = build_m_major_index(int(lmax), 1, device=x_local_4d.device) + return rotate_back_reference(x_std, wigner, coeff, (int(lmax) + 1) ** 2) + return _launch_bd_back_so2_fwd(x_local_4d, wigner, int(lmax)) + + +def _block_rotate_back_so2_bwd_impl( + grad_out: Tensor, x_local_4d: Tensor, wigner: Tensor, lmax: int +) -> tuple[Tensor, Tensor]: + if not _use_triton(x_local_4d): + n_edge, n_focus, reduced_dim, focus_dim = x_local_4d.shape + x_std = x_local_4d.transpose(1, 2).reshape( + n_edge, reduced_dim, n_focus * focus_dim + ) + coeff = build_m_major_index(int(lmax), 1, device=x_local_4d.device) + grad_x_std, grad_wigner = _rotate_back_bwd_eager( + grad_out, x_std, wigner, coeff, (int(lmax) + 1) ** 2 + ) + grad_x_local = grad_x_std.reshape( + n_edge, reduced_dim, n_focus, focus_dim + ).transpose(1, 2) + return grad_x_local, grad_wigner + return _launch_bd_back_so2_bwd(grad_out.contiguous(), x_local_4d, wigner, int(lmax)) + + +_block_back_so2_op = torch.library.custom_op( + "sezm_triton::rotate_back_block_so2", mutates_args=() +)(_block_rotate_back_so2_impl) + +_block_back_so2_bwd_op = torch.library.custom_op( + "sezm_triton::rotate_back_block_so2_bwd", mutates_args=() +)(_block_rotate_back_so2_bwd_impl) + + +@_block_back_so2_op.register_fake +def _(x_local_4d, wigner, lmax): + n_edge, n_focus, _reduced, focus_dim = x_local_4d.shape + return x_local_4d.new_empty((n_edge, (int(lmax) + 1) ** 2, n_focus * focus_dim)) + + +@_block_back_so2_bwd_op.register_fake +def _(grad_out, x_local_4d, wigner, lmax): + return torch.empty_like(x_local_4d), torch.empty_like(wigner) + + +def _block_back_so2_setup_context(ctx, inputs, output): + x_local_4d, wigner, lmax = inputs + ctx.save_for_backward(x_local_4d, wigner) + ctx.lmax = lmax + + +def _block_back_so2_backward(ctx, grad_out): + x_local_4d, wigner = ctx.saved_tensors + grad_x_local, grad_wigner = _block_back_so2_bwd_op( + grad_out, x_local_4d, wigner, ctx.lmax + ) + return grad_x_local, grad_wigner, None + + +_block_back_so2_op.register_autograd( + _block_back_so2_backward, setup_context=_block_back_so2_setup_context +) + + +def rotate_back_block_so2(x_local_4d: Tensor, wigner: Tensor, lmax: int) -> Tensor: + """Block-diagonal ``local -> global`` rotation reading the per-focus layout. + + Parameters + ---------- + x_local_4d : Tensor + Local features with shape (E, F, reduced_dim, Cf) in the canonical m-major + ``mmax=1`` layout, where C_wide = F * Cf. + wigner : Tensor + Transposed Wigner-D with shape (E, D, D), D = (lmax + 1) ** 2. + lmax : int + Maximum degree. + + Returns + ------- + Tensor + Global-frame message with shape (E, D, C_wide). The per-focus to packed + channel mapping ``c = f * Cf + cf`` folds the inverse transpose into the + kernel addressing, avoiding an explicit copy. + """ + return _block_back_so2_op(x_local_4d, wigner, int(lmax)) diff --git a/source/tests/pt/model/test_descriptor_sezm_triton.py b/source/tests/pt/model/test_descriptor_sezm_triton.py index 4497361223..e74c31d72a 100644 --- a/source/tests/pt/model/test_descriptor_sezm_triton.py +++ b/source/tests/pt/model/test_descriptor_sezm_triton.py @@ -42,6 +42,7 @@ from deepmd.pt.model.descriptor.sezm_nn.triton.so2_rotation import ( TRITON_ROTATION_AVAILABLE, rotate_back_block, + rotate_back_block_so2, rotate_back_dense, rotate_back_reference, rotate_to_local_block, @@ -220,6 +221,55 @@ def test_eager_rotate_back_forward_backward_matches_reference(self): xl0, w0, coeff_index, dim = self._local_inputs(lmax, seed=lmax) self._assert_back_matches_reference(xl0, w0, coeff_index, dim) + def test_rotate_back_so2_matches_block_on_focus_layout(self): + """The layout-aware rotate_back reads the per-focus layout (E, F, D_m, Cf) + directly and reproduces the standard block kernel on the transposed + (E, D_m, F * Cf) input, forward and backward (including grad_wigner on the + block entries). This is the in-place transpose the SeZM SO(2) pipeline + avoids materializing. + """ + n_focus, focus_dim = 2, 8 + for lmax in (2, 3, 4, 5): + with self.subTest(lmax=lmax): + gen = torch.Generator(device=self.device).manual_seed(lmax) + reduced = 3 * lmax + 1 + mask = _block_mask(lmax, self.device) + x4 = torch.randn( + self.n_edge, + n_focus, + reduced, + focus_dim, + device=self.device, + dtype=self.dtype, + generator=gen, + ) + w0 = _block_diagonal_wigner( + self.n_edge, lmax, self.device, self.dtype, gen + ) + + xa = x4.clone().requires_grad_(True) + wa = w0.clone().requires_grad_(True) + out = rotate_back_block_so2(xa, wa, lmax) + + x_std = x4.transpose(1, 2).reshape( + self.n_edge, reduced, n_focus * focus_dim + ) + xr = x_std.clone().requires_grad_(True) + wr = w0.clone().requires_grad_(True) + ref = rotate_back_block(xr, wr, lmax) + torch.testing.assert_close(out, ref, **self.tol) + + grad_out = torch.randn_like(ref) + gxa, gwa = torch.autograd.grad( + out, [xa, wa], grad_out, retain_graph=True + ) + gxr, gwr = torch.autograd.grad(ref, [xr, wr], grad_out) + gxa_std = gxa.transpose(1, 2).reshape( + self.n_edge, reduced, n_focus * focus_dim + ) + torch.testing.assert_close(gxa_std, gxr, **self.tol) + torch.testing.assert_close(gwa[:, mask], gwr[:, mask], **self.tol) + def test_symbolic_make_fx_rotate_to_local_forward_backward_matches_eager(self): """Symbolic FX captures rotate_to_local forward and autograd graph.""" lmax = 3 From 1c3f110c5dfb69022bcc42b047c29390b9fd14ca Mon Sep 17 00:00:00 2001 From: OutisLi Date: Wed, 10 Jun 2026 16:01:42 +0800 Subject: [PATCH 16/18] bugfix --- deepmd/pt/model/descriptor/sezm_nn/so2.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/deepmd/pt/model/descriptor/sezm_nn/so2.py b/deepmd/pt/model/descriptor/sezm_nn/so2.py index 9ff0c9fac4..de2ba59e9a 100644 --- a/deepmd/pt/model/descriptor/sezm_nn/so2.py +++ b/deepmd/pt/model/descriptor/sezm_nn/so2.py @@ -360,7 +360,13 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: self._cached_weight = weight.detach() # === Step 3. Block-diagonal matmul over focus streams + reshape back === - out_flat = self._block_diagonal_matmul(x_flat, weight) + # On CPU the block ``torch.cat`` lowers to a nested masked select that + # trips an Inductor AVX2 C++ codegen bug under compile, so use the + # numerically identical dense einsum there; GPU keeps the block-diag opt. + if x_flat.is_cuda: + out_flat = self._block_diagonal_matmul(x_flat, weight) + else: + out_flat = torch.einsum("efi,ifo->efo", x_flat, weight) out = out_flat.reshape( n_edge, self.n_focus, self.reduced_dim, self.out_channels ) From 1dcb1c5c146281bbe9aa4c32e38920172674e1cc Mon Sep 17 00:00:00 2001 From: OutisLi Date: Wed, 10 Jun 2026 16:58:34 +0800 Subject: [PATCH 17/18] fix freeze --- .../descriptor/sezm_nn/triton/radial_mix.py | 66 ++++++++++++++++-- source/tests/pt/model/test_sezm_export.py | 67 +++++++++++++++++-- 2 files changed, 120 insertions(+), 13 deletions(-) diff --git a/deepmd/pt/model/descriptor/sezm_nn/triton/radial_mix.py b/deepmd/pt/model/descriptor/sezm_nn/triton/radial_mix.py index ba9cf5bd52..12ba2de25c 100644 --- a/deepmd/pt/model/descriptor/sezm_nn/triton/radial_mix.py +++ b/deepmd/pt/model/descriptor/sezm_nn/triton/radial_mix.py @@ -127,13 +127,65 @@ def radial_mix_reference( def _radial_mix_backward_reference( grad_out: Tensor, compact: Tensor, x_local: Tensor, channel_basis: Tensor, lmax: int ) -> tuple[Tensor, Tensor]: - """Eager backward returning ``(grad_compact, grad_x_local)`` via autograd.""" - with torch.enable_grad(): - compact_req = compact.detach().requires_grad_(True) - x_req = x_local.detach().requires_grad_(True) - out = radial_mix_reference(compact_req, x_req, channel_basis, lmax) - grad_compact, grad_x = torch.autograd.grad(out, [compact_req, x_req], grad_out) - return grad_compact, grad_x + """Closed-form eager backward of :func:`radial_mix_reference`. + + Gradients are evaluated analytically per diagonal block, mirroring the + contractions of the Triton backward. A closed form is required rather than a + nested ``autograd.grad``: this routine is the CPU backend of the + ``radial_mix_block_bwd`` operator, which carries no autograd formula and is + consequently dispatched under ``_AutoDispatchBelowAutograd`` whenever the + force graph is replayed without grad (the SeZM ``.pt2`` freeze does so under + :func:`torch.no_grad`). That guard excludes the autograd key, so a nested + ``autograd.grad`` would observe an output without a ``grad_fn``. + + Parameters + ---------- + grad_out : Tensor + Upstream gradient with shape ``(E, reduced_dim, C)``. + compact : Tensor + Projected radial degree kernel with shape ``(E, degree_kernel_size, R)``. + x_local : Tensor + Edge-local reduced features with shape ``(E, reduced_dim, C)``. + channel_basis : Tensor + Per-rank channel basis with shape ``(R, C)``. + lmax : int + Maximum spherical-harmonic degree. + + Returns + ------- + tuple[Tensor, Tensor] + Gradients ``(grad_compact, grad_x_local)``, matching ``compact`` and + ``x_local`` in shape respectively. + """ + n_edge, reduced_dim, channels = x_local.shape + grad_x_local = torch.zeros_like(x_local) + grad_compact = torch.zeros_like(compact) + for coeff0, comp0, num_l in _block_layout(int(lmax)): + # Forward of this block (see ``radial_mix_reference``): + # out[e, o, c] = sum_{i, r} K[e, o, i, r] * x[e, i, c] * cb[r, c] + # with K[e, o, i, r] = compact[e, comp0 + i * num_l + o, r]. + k_block = ( + compact[:, comp0 : comp0 + num_l * num_l, :] + .reshape(n_edge, num_l, num_l, -1) + .permute(0, 2, 1, 3) + ) # (E, o, i, R) + x_block = x_local[:, coeff0 : coeff0 + num_l, :] # (E, i, C) + g_block = grad_out[:, coeff0 : coeff0 + num_l, :] # (E, o, C) + + # grad_x[e, i, c] = sum_r cb[r, c] * sum_o K[e, o, i, r] * g[e, o, c]. + gx = torch.einsum("eoir,eoc->eicr", k_block, g_block) # (E, i, C, R) + grad_x_local[:, coeff0 : coeff0 + num_l, :] += torch.einsum( + "eicr,rc->eic", gx, channel_basis + ) + + # grad_K[e, o, i, r] = sum_c cb[r, c] * x[e, i, c] * g[e, o, c], scattered + # back to the compact slot comp0 + i * num_l + o. The shared m = +-1 + # blocks address the same slots, so the in-place add accumulates both. + gk = torch.einsum("eoc,eic,rc->eoir", g_block, x_block, channel_basis) + grad_compact[:, comp0 : comp0 + num_l * num_l, :] += gk.permute( + 0, 2, 1, 3 + ).reshape(n_edge, num_l * num_l, -1) + return grad_compact, grad_x_local # ====================================================================== diff --git a/source/tests/pt/model/test_sezm_export.py b/source/tests/pt/model/test_sezm_export.py index c1d143c21c..7398cf7cee 100644 --- a/source/tests/pt/model/test_sezm_export.py +++ b/source/tests/pt/model/test_sezm_export.py @@ -13,6 +13,7 @@ import contextlib import copy import json +import os import tempfile import unittest import zipfile @@ -216,18 +217,24 @@ class TestSeZMExportPipeline(_ClearDefaultDeviceTestCase): it must reproduce the eager result exactly. Drift here implies a bug in ``forward_common_lower_exportable`` or the dynamic-shape spec, not in AOTI. The pipeline is built once per class because - ``make_fx`` and ``.pte`` round-trip dominate wall time. + ``make_fx`` and ``.pte`` round-trip dominate wall time. Subclasses + set ``TRITON_INFER`` to drive the identical pipeline through the + opt-in Triton inference kernels. """ + # ``DP_TRITON_INFER`` policy applied while the model is constructed. + TRITON_INFER = "0" + @classmethod def setUpClass(cls) -> None: super().setUpClass() try: - cls.model = _build_tiny_sezm_model() - cls.sample_inputs = _make_sample(cls.model, nloc=7, start=2) - cls.traced, cls.loaded, cls._pte_tmp = cls._build_pipeline( - cls.model, cls.sample_inputs - ) + with mock.patch.dict(os.environ, {"DP_TRITON_INFER": cls.TRITON_INFER}): + cls.model = _build_tiny_sezm_model() + cls.sample_inputs = _make_sample(cls.model, nloc=7, start=2) + cls.traced, cls.loaded, cls._pte_tmp = cls._build_pipeline( + cls.model, cls.sample_inputs + ) except Exception: super().tearDownClass() raise @@ -327,6 +334,54 @@ def test_loaded_pte_matches_eager_different_shape(self) -> None: ) +class TestSeZMExportPipelineTritonInfer(TestSeZMExportPipeline): + """The same trace / ``.pte`` pipeline exercised with ``DP_TRITON_INFER=1``. + + Inheriting the parity suite asserts the Triton-enabled model still traces, + exports, and reloads, and that the loaded ``.pte`` reproduces its eager + forward — including the force path, whose custom ``*_bwd`` ops run inside + ``_AutoDispatchBelowAutograd`` during the export's no-grad replay and must + therefore be closed-form. Two checks are added: the captured graph routes + through the custom ops, and the Triton ``.pte`` matches the dense (Triton-off) + inference, proving ``DP_TRITON_INFER`` swaps the implementation without + changing results. + """ + + TRITON_INFER = "1" + + @classmethod + def setUpClass(cls) -> None: + super().setUpClass() + try: + with mock.patch.dict(os.environ, {"DP_TRITON_INFER": "0"}): + dense_model = _build_tiny_sezm_model() + cls.dense_out = _eager_forward(dense_model, cls.sample_inputs) + except Exception: + super().tearDownClass() + raise + + @classmethod + def tearDownClass(cls) -> None: + try: + if hasattr(cls, "dense_out"): + delattr(cls, "dense_out") + finally: + super().tearDownClass() + + def test_force_graph_carries_triton_ops(self) -> None: + """``DP_TRITON_INFER=1`` must route the descriptor through the custom ops.""" + code = self.traced.code + self.assertIn("radial_mix_block_bwd", code) + self.assertIn("rotate_to_local", code) + + def test_loaded_pte_matches_dense(self) -> None: + """The Triton-on ``.pte`` reproduces the dense-path inference.""" + loaded_out = self.loaded(*self.sample_inputs) + self._assert_dict_allclose( + self.dense_out, loaded_out, context="triton .pte vs dense eager" + ) + + class _FrozenPt2Fixture(_ClearDefaultDeviceTestCase): """Shared setUp/tearDown: freeze a tiny SeZM checkpoint to ``.pt2`` once. From 89c5b1d2714cff636c699c4cf731d98f632cc7a8 Mon Sep 17 00:00:00 2001 From: OutisLi Date: Wed, 10 Jun 2026 18:07:58 +0800 Subject: [PATCH 18/18] use block diag when freezing --- deepmd/pt/entrypoints/freeze_pt2.py | 14 ++++++++++++++ deepmd/pt/model/descriptor/sezm_nn/so2.py | 20 ++++++++++++++++---- 2 files changed, 30 insertions(+), 4 deletions(-) diff --git a/deepmd/pt/entrypoints/freeze_pt2.py b/deepmd/pt/entrypoints/freeze_pt2.py index 7adce22d1d..c85671b178 100644 --- a/deepmd/pt/entrypoints/freeze_pt2.py +++ b/deepmd/pt/entrypoints/freeze_pt2.py @@ -43,6 +43,9 @@ from deepmd.dpmodel.utils.region import ( normalize_coord, ) +from deepmd.pt.model.descriptor.sezm_nn.so2 import ( + SO2Linear, +) from deepmd.pt.model.model import ( get_model, ) @@ -518,6 +521,17 @@ def freeze_sezm_to_pt2( model.eval() model.to("cpu") + # The SO(2) linear mixer selects its block-diagonal vs dense matmul from a + # Python device branch that make_fx resolves at trace time. Since tracing + # always runs on CPU, pin the choice to the AOTI target device: non-CPU + # targets bake the block-diagonal contraction (which skips the structural + # off-|m| zeros); CPU targets keep the dense einsum that dodges the Inductor + # AVX2 codegen bug. + force_block_diag = target_device.type != "cpu" + for module in model.modules(): + if isinstance(module, SO2Linear): + module._force_block_diag_matmul = force_block_diag + _, sample_inputs_cpu = _resolve_nframes( model, nloc=7, diff --git a/deepmd/pt/model/descriptor/sezm_nn/so2.py b/deepmd/pt/model/descriptor/sezm_nn/so2.py index de2ba59e9a..91839c82d4 100644 --- a/deepmd/pt/model/descriptor/sezm_nn/so2.py +++ b/deepmd/pt/model/descriptor/sezm_nn/so2.py @@ -326,6 +326,11 @@ def __init__( # Invalidated on train() via overridden method below. self._cached_weight: torch.Tensor | None = None + # Export override for the block-diagonal vs dense matmul branch below. + # ``None`` keeps the runtime ``x_flat.is_cuda`` dispatch; the freeze sets + # it so the AOTI graph follows the *target* device, not the CPU trace. + self._force_block_diag_matmul: bool | None = None + # The assembled SO(2) weight is block-diagonal over |m| groups; the # forward contracts only the diagonal blocks (see _block_diagonal_matmul). # Each |m| group occupies a contiguous (in, out) block on the diagonal. @@ -360,10 +365,17 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: self._cached_weight = weight.detach() # === Step 3. Block-diagonal matmul over focus streams + reshape back === - # On CPU the block ``torch.cat`` lowers to a nested masked select that - # trips an Inductor AVX2 C++ codegen bug under compile, so use the - # numerically identical dense einsum there; GPU keeps the block-diag opt. - if x_flat.is_cuda: + # The dense einsum is a CPU-only fallback: its block ``torch.cat`` lowering + # trips an Inductor AVX2 C++ codegen bug, so only CPU needs it. Every other + # device uses the block-diagonal contraction, which skips the structural + # off-|m| zeros. ``make_fx`` resolves this Python branch at trace time, so + # the freeze pins ``_force_block_diag_matmul`` to the AOTI target device + # (tracing always runs on CPU regardless of where the artifact will run). + if self._force_block_diag_matmul is None: + use_block_diag = not x_flat.is_cpu + else: + use_block_diag = self._force_block_diag_matmul + if use_block_diag: out_flat = self._block_diagonal_matmul(x_flat, weight) else: out_flat = torch.einsum("efi,ifo->efo", x_flat, weight)