From 8e107b327de7a4767d95d8f134fa62c74fa6ece4 Mon Sep 17 00:00:00 2001 From: lmoresi Date: Mon, 11 May 2026 21:04:52 +1000 Subject: [PATCH 01/15] docs: add in-memory checkpoint / snapshot toolkit design note MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Spun off from the 2026-05-11 deformable-surface design discussion as a self-contained UW3 capability. Covers motivation (backtrack-on-failure, adaptive Δt retry, RK staging, predictor-corrector probes, crash recovery, bisection, debugging captures), the two existing on-disk paths (write_checkpoint and write_timestep), the state-as-dataclass serialisation contract for solver-internal Python state, the three-backend story (in-memory + on-disk-full-state + existing write_timestep unchanged), schema versioning, the swarm population-generation counter, eight architectural work items in dependency order, scope boundaries, and open implementation questions. Baseline-of-record for the feature/in-memory-checkpoint branch. Underworld development team with AI support from Claude Code (https://claude.com/claude-code) --- .../design/in_memory_checkpoint_design.md | 474 ++++++++++++++++++ 1 file changed, 474 insertions(+) create mode 100644 docs/developer/design/in_memory_checkpoint_design.md diff --git a/docs/developer/design/in_memory_checkpoint_design.md b/docs/developer/design/in_memory_checkpoint_design.md new file mode 100644 index 00000000..4c3a3676 --- /dev/null +++ b/docs/developer/design/in_memory_checkpoint_design.md @@ -0,0 +1,474 @@ +# In-memory checkpoint as a general UW3 capability + +Design note. Spun off from the deformable-surface architectural +discussion (2026-05-11) as a self-contained capability that is bounded, +useful in its own right, and not specifically tied to free-surface code. +The companion note that motivated this is +`docs/developer/design/deformable_surface_metronome_design_note.md` (in +the `feature/exp-integrator-freesurface` worktree). + +## Motivation + +**Primary use case: backtrack past unstable timestepping.** A timestepper +hits an instability or a sudden regime change; the cleanest recovery is +to restore the last known-good state and continue with smaller Δt +(or a different scheme). Today this is done ad-hoc by users where it +happens at all. + +**Secondary uses, all sharing the same primitive:** + +- Multi-stage time integration (RK4 between stages — restore to + start-of-step, deform to next stage configuration, sample rate). The + free-surface integrator-zoo work is the immediate case; the same + primitive applies to any multi-stage scheme. +- Adaptive Δt with error estimate (take a step, estimate per-step + error, restore + retry if too large). +- Predictor-corrector probing (try a predictor, check the corrector + residual, fall back if not converging — relevant to the VEP work). +- Regime-change feeling-out (e.g., elastic predictor → check yield + surface → restore + plastic split if violated). + +All five want one thing: at one moment, capture *enough* state that the +system can be put back at any point afterwards as if nothing had +happened. + +## Reframe: same operation as on-disk checkpoint, different backend + +Rather than build a parallel snapshot mechanism, treat in-memory +snapshot as a *backend variant* of the existing checkpoint code. Two +benefits: + +1. **One source of truth for "what is system state."** Whatever set of + (mesh coords, MV DOFs, swarm positions, swarm-var values, algorithm + history, ...) the checkpoint serialiser captures becomes the + contract for in-memory snapshots too. Adding new state-bearers means + adding them once; both backends pick them up. +2. **The existing checkpoint code gets exercised harder and improved + as a side effect.** In-memory roundtrip is a cheap unit test — + every snapshot/restore cycle exercises the serialise/deserialise + paths, which surfaces gaps and bugs that disk-only checkpointing + exposes only at quarterly-test cadence. + +## Two checkpoint paths today (audit-confirmed) + +The audit (read-only investigation, current `main`) confirmed two +distinct on-disk paths in `src/underworld3/`: + +**Path A — PETScSection-based (native checkpoint).** +- `Mesh.write_checkpoint(...)` at `discretisation/discretisation_mesh.py:1892–1953`. +- Uses `dm.sectionView(viewer, subdm)` and `dm.globalVectorView(viewer, subdm, var._gvec)` + through a `PETSc.ViewerHDF5()`. +- Captures mesh topology, deformed coordinates, MV DOF values, and + swarm-variable values via the `_meshVar` proxy. +- Lower-level; bound to the producing DM. + +**Path B — write_timestep (visualization-oriented).** +- `Mesh.write_timestep(...)` at `discretisation/discretisation_mesh.py:1750–1830` + and `Swarm.write_timestep(...)` at `swarm.py:3726–3810`. +- Per-variable HDF5 + XDMF for ParaView; mixes `PETSc.ViewerHDF5()` + with direct `h5py` writes (`swarm.py:1772–1850`). +- More flexible — can be re-loaded at different resolution / + decomposition; bulkier; the user-facing visualisation pipeline. + +**For in-memory snapshot, Path A is the conceptual model.** Restore +goes back to the same DM, so the resolution/decomposition flexibility +of Path B is unneeded. But the in-memory backend will not actually use +the HDF5 Viewer — it will copy section structure + global-vector data +directly into numpy arrays. Same conceptual capture, different +mechanism. + +## What state must be captured + +Audit-informed inventory: + +**Captured today (in Path A):** +- DM topology and section +- Deformed mesh coordinates +- Mesh-variable global vectors (DOF values) — **including `DDt.psi_star`, + which is itself a mesh variable; its DOF data goes through the + section path automatically** +- Swarm-variable values (via mesh-DM proxy fields) + +**NOT captured today, required for full-state in-memory restore:** + +| state | location | priority | +|---|---|---| +| `DDt._dt_history` (variable-Δt BDF history; list of floats) | `systems/ddt.py:386` | **high** | +| `DDt._history_initialised` (bool) | `systems/ddt.py:383` | **high** | +| `DDt._n_solves_completed` (int) | `systems/ddt.py:384` | **high** | +| Binding `DDt instance ↔ psi_star MV(s)` | implicit in `__init__` | high | +| Simulation time, step counter | not in any object today | high | +| Parameter mutation history | `parameters.py:145` (`_history`) | medium | +| `Swarm._mesh_version` | `swarm.py:2421` | medium | +| Solver iteration counts / convergence history | scattered | low | + +**The DDt example is representative, not exceptional.** Algorithm-internal +scalar/list state on Python objects is the architecturally significant +gap — it lives in solver-adjacent Python objects, not in any PETScSection. +The in-memory snapshot must compose PETScSection state + Python-side +mutable state into a single token. The next section addresses how this +composition should be designed in general — not as a per-class bolt-on, +but as a contract that new algorithm-helper classes follow from day one. + +## General serialisation contract for solver-internal state + +The opportunity. Bolting per-class snapshot hooks onto each algorithm +helper as it appears would produce a half-baked system that silently +misses state whenever a new helper is added without a corresponding hook. +DDt today, the next adaptive-Δt controller tomorrow, the gamma estimator +the day after — the pattern recurs. + +This is the right moment to design a general serialisation contract for +solver-internal state, before it leaks into the user API. New +algorithm-helper classes should declare their state slots from day one; +existing classes (DDt, parameter mutation history, solver convergence +bookkeeping) get retrofitted as they're touched. + +**Three options for the contract.** + +(A) **Declarative slots.** Class declares a class-level list of +attribute names that constitute state. +```python +class DDt: + _state_attrs = ('_dt_history', '_history_initialised', + '_n_solves_completed') +``` +Snapshot copies named attrs; restore writes them back. Stringly-typed; +silent breakage when attribute names drift. Least invasive — works as a +mixin on any existing class. + +(B) **Explicit save/restore methods.** Class implements paired methods. +```python +class DDt: + def _save_state(self) -> dict: ... + def _restore_state(self, d: dict) -> None: ... +``` +Most flexible (transform on save, validate on restore). Most boilerplate; +easy to forget to update when adding a new state attribute. Standard +Python pickling pattern (`__getstate__` / `__setstate__`) is essentially +this option. + +(C) **State as a first-class object.** Computation and state are +separated; the State object is a dataclass / pydantic model that +trivially serialises. +```python +@dataclass +class DDtState: + dt_history: list[float] + history_initialised: bool + n_solves_completed: int + psi_star_var_names: list[str] # binding to MVs lives here too + +class DDt: + def __init__(self, ...): + self.state = DDtState(...) + # All mutations go via self.state. +``` +Most self-documenting; enables side benefits beyond serialisation +(deep-copy, equality testing, repr, schema-versioned migrations). Most +invasive — changes how new solver-internal classes are written. + +**Recommendation.** (C) for new code, (B)-style adapters for the small +number of existing classes that need retrofitting (primarily DDt and +parameter-mutation-history). (A) is rejected because the silent-drift +failure mode is exactly the kind of half-baked outcome we're trying to +avoid by designing this now. + +The decision matters because it sets the pattern for every +algorithm-internal class added over the next few years. The cost of +choosing wrong is that snapshot tokens silently miss state for any +class added without proper instrumentation; the cost of choosing right +is that new solver-internal classes get serialisation, deep-copy, and +equality testing automatically. + +**Side benefits of (C) beyond checkpoint.** +- `solver_a.state == solver_b.state` becomes a meaningful comparison — + useful for regression testing of solver-internal behaviour +- `repr(obj.state)` is automatic and useful for debugging +- snapshot tokens compose trivially: `{id(obj): obj.state.copy()}` +- schema versioning is tractable: a `_schema_version` field on each + State object plus a migration registry handles cross-version + compatibility (relevant for on-disk checkpoints that survive UW3 + upgrades) +- the `Snapshottable` interface becomes a one-liner: "has a `.state` + attribute that is a dataclass." + +**Bindings to PETSc-owned objects.** State objects must NOT hold direct +PETSc Vec / DM / MV handles — only stable identifiers (variable names, +mesh names) that can resolve back to the live object on restore. This +keeps tokens plain-Python and avoids the DM-lifecycle hazards +identified in earlier work. + +## A third backend: on-disk full-state snapshot + +The in-memory backend is one storage option. The same capture/restore +machinery supports a slower-but-persistent **on-disk full-state** +backend with no architectural changes — only the serialisation layer +differs. + +This is **distinct from the existing `write_timestep` on-disk path**, +which is selective (per-variable), designed for visualisation, emits +XDMF for ParaView, and does not restore solver-internal state. Three +backends serving three different needs: + +| dimension | existing `write_timestep` | new on-disk full-state | new in-memory | +|---|---|---|---| +| storage | HDF5 + XDMF, per-variable | HDF5, monolithic | dict of numpy arrays | +| selectivity | user picks variables | always full state | always full state | +| Python-side state | not captured | captured | captured | +| restorability | partial (re-init solvers, lose history) | bit-equivalent | bit-equivalent | +| persistence | survives process exit | survives process exit | intra-run only | +| speed | medium | slow (HDF5 I/O) | fast (RAM copy) | +| typical use | visualisation, restart-with-changes | crash recovery, bisection, debugging | RK staging, backtrack, adaptive Δt | + +The existing `write_timestep` continues to serve its role unchanged. +The new mechanism (in-memory + on-disk-full-state) addresses a +different need: faithful state restore for algorithmic uses where +"approximate restart" would silently corrupt the run. + +**Use cases the on-disk full-state backend opens up:** + +- **Crash recovery for long simulations.** Periodic snapshots written + to disk; on restart, restore from the most recent. No lost work + beyond the snapshot interval. +- **Cross-run resumption.** A simulation that ran to $t = T_1$ can be + picked up at $t = T_1$ in a later session, bit-equivalent — including + DDt history and any other algorithm-internal state. +- **Bisection / branching exploration.** Snapshot at a decision point; + try one parameter path; if unsatisfactory, restore and try another. + Powerful for sensitivity studies on long runs. +- **Debugging captures.** Snapshot at a problem point; examine offline; + iterate without re-running the costly setup. + +**What this adds to the design.** Three implications, all bounded: + +1. **Backend abstraction must serialise to bytes from day one.** The + in-memory backend stores numpy arrays directly; the on-disk backend + serialises them to HDF5. Both go through the same `save_*` / `load_*` + protocol on the backend interface — the abstraction is shaped by + needing both. +2. **Schema versioning becomes critical.** In-memory tokens are + short-lived; on-disk tokens may be loaded by a future UW3 version. + Each `State` dataclass carries a `_schema_version`; a migration + registry handles cross-version restore. (The on-disk + `write_timestep` path has no equivalent need because it doesn't + restore solver state.) +3. **Generation-counter semantics need a cross-process branch.** + Within-process: counter check ensures you don't restore an + invalidated snapshot. Across-process (restoring from disk on a + fresh run): the model is being initialised *from* the snapshot, + not restored *to* a previous state — counters are set to the + snapshot's values, no invalidation check needed. Restore code + distinguishes the two paths. + +## API shape + +Full-state always; backend chosen at snapshot time by passing (or +omitting) a path. + +```python +# Backend selection — same capture, different storage layer +token = model.snapshot() # in-memory (default) +model.snapshot(path='step42.snap.h5') # on-disk full-state + +model.restore(token) # in-memory restore +model.restore('step42.snap.h5') # on-disk restore + +# Existing per-variable selective on-disk path is unchanged: +mesh.write_timestep('step42.h5', ...) # visualisation; not full-state +``` + +Backends share a single `Snapshot` structure — only the serialisation +layer differs: + +```python +class Snapshot: + section_state: dict[DM, SectionAndVecData] # PETSc state + python_state: dict[ObjectId, ObjectState] # State dataclasses + generations: dict[ObjectId, int] # within-process invalidation + metadata: dict # sim time, step #, schema version +``` + +The in-memory backend stores `section_state` arrays as numpy buffers +directly; the on-disk backend serialises them to HDF5 datasets. The +Python-side `State` dataclasses serialise to HDF5 attributes (small +scalars / short lists) and groups (arrays). + +Tokens are plain Python / numpy — never PETSc Vec or DM handles. +On-disk "tokens" are paths to single HDF5 files. Either way, the +DM-lifecycle hazards identified in earlier work do not apply. + +## Generation counter for swarm invalidation + +Snapshots are valid until the population that produced them changes. +The audit identified the mutation sites: + +| file | line(s) | call | +|---|---|---| +| `swarm.py` | 3083, 3085, 3109 | `populate()` → `dm.addNPoints()` | +| `swarm.py` | 3365 | `add_particles_with_coordinates()` | +| `swarm.py` | 3449 | `add_particles_with_global_coordinates()` | +| `swarm.py` | 3223, 3382 | `migrate(remove_sent_points=True)` | +| `swarm.py` | 4298 | remesh/repopulate path | +| `discretisation_mesh.py` | 3090 | `Mesh.adapt()` (indirect via `_mesh_version`) | + +A single `Swarm._population_generation` counter, incremented at all +seven sites, is sufficient. `Mesh._mesh_version` already exists for the +mesh-side analogue. + +```python +def swarm.restore(token): + if swarm._population_generation != token.generation_at_snapshot: + raise SnapshotInvalidatedError( + "swarm population changed since snapshot — restore not safe") + # ... write positions back, write svar values back, migrate +``` + +Constraint documented as part of the contract: snapshots cannot survive +a population-change event. Consumers that take long-lived snapshots +across such events get a clear error rather than silent corruption. + +## Architectural work required + +In rough dependency order: + +1. **Backend abstraction layer.** Extract the "capture" logic from the + "store" logic in `Mesh.write_checkpoint()`. Introduce a + `CheckpointBackend` protocol (e.g., `save_section`, `save_vector`, + `save_metadata`, `load_section`, `load_vector`, `load_metadata`). + Shaped from day one to support both backends — the in-memory case + is the cheapest test of the abstraction's correctness; the on-disk + case is what locks in the byte-level serialisation contract. The + existing HDF5 path is refactored to fit the protocol; the wedge + is clean because section/vector ops are PETSc abstractions and the + HDF5 coupling lives in the Viewer construction. +2. **Backend implementations.** Two concrete backends: + - `InMemoryBackend` — dict of numpy arrays. Trivial once the + abstraction exists. + - `OnDiskFullStateBackend` — single monolithic HDF5 file. Shares + PETSc-state serialisation with the existing `write_checkpoint` + path (already HDF5); adds Python-state serialisation as HDF5 + attributes/groups. +3. **Adopt the serialisation contract for new solver-internal code.** + Decision: option (C) — state as first-class dataclass (see "General + serialisation contract" section). Any new algorithm-helper class + added from this point on declares its state as a separate + dataclass; the checkpoint mechanism reads `.state` automatically. + Document the pattern in the developer guide; add a check to CI that + new classes in `systems/` and `solvers/` declare `.state`. +4. **Retrofit existing solver-internal classes** to the contract. + In priority order: `DDt` (`systems/ddt.py`), parameter mutation + history (`parameters.py:145`), any solver convergence-tracking + state (audit pending). Each retrofit is small; total bounded by + the number of classes (probably under ten). +5. **Swarm `_population_generation` counter.** Bumped at the seven + identified sites; checked on restore (within-process only — see + item 6 for the cross-process semantics). +6. **Schema versioning + migration registry.** Each `State` dataclass + carries a `_schema_version` integer. A central registry maps + `(class, version)` to migration functions that lift older State + data to the current schema. In-memory restore checks version + equality (any mismatch is a programming error since both sides are + the same process); on-disk restore consults the migration registry + and applies migrations in sequence. Without this infrastructure, + on-disk snapshots break the moment any retrofitted class evolves. +7. **`Model.snapshot()` / `Model.restore()` orchestration.** Walks + registered meshes, swarms, MVs, and any object exposing a `.state` + attribute; composes the token; routes to backend (in-memory or + on-disk) based on whether `path=` is given; restores in safe order; + distinguishes within-process restore (counter validation enforced) + from cross-process restore (counters initialised from snapshot). +8. **Tests.** In-memory roundtrip on standard setups (Stokes benchmark, + swarm-bearing setup, DDt-using setup). On-disk roundtrip on the + same setups. Cross-process restore: write to disk in process A, + load in fresh process B, verify the restored model continues + bit-equivalently from the snapshot point. Verify generation + invalidation raises cleanly; verify `.state == .state` after + roundtrip for retrofitted classes. + +## Scope boundaries (NOT in v1) + +- **Mesh adaptation roundtrip.** A snapshot taken before a mesh + adaptation event cannot be restored after the adaptation — the DM + identity has changed. Documented as a contract limitation; the + generation-counter pattern detects and refuses for in-memory + restores. (For on-disk restore in a fresh process, the question + doesn't arise — the model is being initialised from the snapshot.) +- **Replacing the existing `write_timestep` path.** The selective + per-variable on-disk path continues unchanged. The new full-state + on-disk backend is additive and serves a different need (faithful + restore vs visualisation/restart-with-changes). +- **Cross-rank-count restore.** On-disk full-state snapshots written + on N ranks restore on N ranks. Restoring on a different rank count + requires the existing `write_timestep` path's interpolation + machinery and is out of scope for the full-state backend. +- **Lazy / copy-on-write in-memory tokens.** Always-eager copy for + v1. The use cases that drove this (RK staging, single-step + backtrack) all happen on a per-step cadence where eager copy is + affordable. Lazy/COW is an optimisation for later if someone proves + it matters. + +## Open implementation questions + +1. **Does PETSc expose a clean way to copy section + vector state + directly into numpy without going through a Viewer?** The PETSc + `Vec.getArray()` / `Section.getDof()` low-level APIs should + suffice, but it needs a quick prototype to confirm. Alternative: + `PETSc.Viewer.Type.MEMORY` may exist but support for DM operations + is uncertain. +2. **How does `Model.snapshot()` discover state-bearing objects?** + With the contract decision made (option C — `.state` dataclass), + discovery options narrow to: (a) walk Model's registered solvers + and recurse into anything with a `.state` attr; (b) explicit + registration at construction time. (a) is more ergonomic but + relies on solvers maintaining proper registration with the Model. + (b) is more explicit but requires every state-bearing object to + know about Model. (a) is probably right; verify the existing + solver-Model relationship can support recursive `.state` discovery. +3. **What does `Model.restore(token)` order look like?** Mesh + coords first (since MV DOFs are tied to mesh layout), then MV + DOFs, then swarm positions + migrate, then swarm-var values, then + DDt and Python-side state, then generation counters validated last. + Confirm this order is correct; particularly whether MV restoration + needs the mesh to be at restored coords first. +4. **Memory cost on realistic setups.** For a typical coupled-physics + run (mesh + swarm + several MVs + DDt history), what's the + per-snapshot byte cost? Drives the answer to "can users keep N + snapshots in memory without thinking about it." Probably bounded + by one-Stokes-solve worth of memory, but worth measuring early. + +## Status + +Audit complete. Design pending review. Implementation has not started. +The work is bounded (no open-ended research items in any of the seven +architectural-work items above) and decomposes into commits that can +land sequentially: backend abstraction → InMemoryBackend → Python-side +registration → swarm generation counter → Model orchestration → tests. + +Expected size: around four weeks of careful work, dominated by: +- the backend-abstraction refactor (item 1), which touches existing + checkpoint code other things depend on, and must be shaped to + support both backends from day one; +- the contract retrofit (item 4), mechanically small per class but + needing careful audit so no state is silently missed; +- the on-disk backend (item 2b) and schema versioning (item 6), + together adding around a week beyond the in-memory-only scope but + unlocking the cross-run / crash-recovery / bisection use cases. + +The contract decision (item 3) and the schema-versioning decision +(item 6) are the highest-leverage pieces — they set the pattern for +every algorithm-internal class added over the next few years and the +durability guarantee for every on-disk full-state snapshot. Worth +getting right even if they slow the immediate work, because the +alternative is years of half-baked snapshots that silently miss state +or break across UW3 versions. + +## Cross-references + +- `docs/developer/design/deformable_surface_metronome_design_note.md` + (in `feature/exp-integrator-freesurface` worktree) — the parent + design discussion that motivated spinning this off as a separate + capability. +- `publications/free-surface-paper/integrator_zoo_supplementary.md` — + the empirical work that exposed the need for snapshot/restore as + part of multi-stage time integration. From c9c1f247d53758bb20f1a202d5a2912762474fab Mon Sep 17 00:00:00 2001 From: lmoresi Date: Mon, 11 May 2026 21:14:20 +1000 Subject: [PATCH 02/15] checkpoint: in-memory unitary snapshot via Model.snapshot/restore (PR 1) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit First true unitary checkpoint in UW3. Captures mesh coordinates and mesh-variable global-vector DOFs across every registered mesh into a plain-Python token; restores back onto the same Model instance within the same process, bit-equivalent. Distinct from the existing per-variable write_timestep/read_timestep path, which continues to serve visualisation and partial restart unchanged. What lands: - src/underworld3/checkpoint/ — new module - backend.py: CheckpointBackend Protocol + InMemoryBackend (eager copy on save and load; tokens hold numpy data only, never PETSc handles, so DM-lifecycle hazards do not apply) - snapshot.py: Snapshot dataclass + snapshot()/restore() routines. Restore order: mesh coords via _deform_mesh() → MV gvec write + globalToLocal sync. Within-process invalidation gate: _mesh_version mismatch raises SnapshotInvalidatedError before any write happens. - src/underworld3/model.py — Model.snapshot()/restore() thin delegates - tests/test_0007_snapshot_inmemory.py — 6 tier-A level-1 tests: scalar/vector MV roundtrip, snapshot independence from later writes, _mesh_version invalidation, type rejection, NotImplementedError on path= (v1.1 scope). Design open-question resolutions: - Q1 module location: src/underworld3/checkpoint/ (top-level sibling, room to grow into swarm + state-as-dataclass + on-disk in v1.1+) rather than the persistence.py stub. - Q2 PETSc API for in-memory capture: Vec.array (numpy view) + subdm.createSubDM + globalToLocal is sufficient. No Viewer needed. - Q3 restore order verified empirically: mesh._deform_mesh first (rebuilds coord caches + callbacks), then per-var gvec write + globalToLocal sync; _stale_lvec flagged so downstream caches refresh. - Q4 memory budget: not yet measured; deferred to a later PR with a realistic coupled-physics setup. Deviation from PR 1 plan: the plan also mentioned refactoring Mesh.write_checkpoint() to call the new protocol. Skipped here because that path is write-only (no read_checkpoint exists), so refactoring without an exercising load-half is risk without value. The HDF5 backend lands with v1.1 on-disk full-state, where it is load-bearing and the protocol shape can be validated against both backends. Not yet covered (subsequent PRs): - swarm coverage + _population_generation counter (PR 2) - state-as-dataclass contract + DDt retrofit (PR 3) - parameter mutation history + CI check (PR 4) - on-disk full-state backend (v1.1; PR 5) - schema versioning + migration registry (PR 6) - cross-process restore + broader test suite (PR 7) See docs/developer/design/in_memory_checkpoint_design.md for the full roadmap. Underworld development team with AI support from Claude Code (https://claude.com/claude-code) --- src/underworld3/__init__.py | 1 + src/underworld3/checkpoint/__init__.py | 37 ++++ src/underworld3/checkpoint/backend.py | 80 +++++++++ src/underworld3/checkpoint/snapshot.py | 233 +++++++++++++++++++++++++ src/underworld3/model.py | 30 ++++ tests/test_0007_snapshot_inmemory.py | 115 ++++++++++++ 6 files changed, 496 insertions(+) create mode 100644 src/underworld3/checkpoint/__init__.py create mode 100644 src/underworld3/checkpoint/backend.py create mode 100644 src/underworld3/checkpoint/snapshot.py create mode 100644 tests/test_0007_snapshot_inmemory.py diff --git a/src/underworld3/__init__.py b/src/underworld3/__init__.py index ad4b80e6..d37bcb20 100644 --- a/src/underworld3/__init__.py +++ b/src/underworld3/__init__.py @@ -207,6 +207,7 @@ def view(): import underworld3.parameters import underworld3.materials import underworld3.discretisation.persistence +import underworld3.checkpoint from .model import ( Model, diff --git a/src/underworld3/checkpoint/__init__.py b/src/underworld3/checkpoint/__init__.py new file mode 100644 index 00000000..d211adaf --- /dev/null +++ b/src/underworld3/checkpoint/__init__.py @@ -0,0 +1,37 @@ +"""Unitary in-memory (and, later, on-disk) snapshot toolkit. + +The first true unitary checkpoint in Underworld3 — captures enough state +that a Model can be put back exactly as it was, suitable for backtrack on +failure, multi-stage time integration, adaptive-Δt retry, and crash +recovery. + +Distinct from the existing per-variable ``write_timestep`` / +``read_timestep`` path, which serves visualisation and partial restart. +That path stays in service of its existing role. + +See ``docs/developer/design/in_memory_checkpoint_design.md`` for the +design rationale, scope, and roadmap. In v1 (this code), only an +in-memory backend is implemented and only mesh + mesh-variable state is +captured. Subsequent PRs add swarm coverage, solver-internal Python +state (DDt history, parameter mutation history), an on-disk full-state +backend, and schema versioning across UW3 releases. +""" + +from .backend import CheckpointBackend, InMemoryBackend +from .snapshot import ( + SNAPSHOT_SCHEMA_VERSION, + Snapshot, + SnapshotInvalidatedError, + snapshot, + restore, +) + +__all__ = [ + "CheckpointBackend", + "InMemoryBackend", + "SNAPSHOT_SCHEMA_VERSION", + "Snapshot", + "SnapshotInvalidatedError", + "snapshot", + "restore", +] diff --git a/src/underworld3/checkpoint/backend.py b/src/underworld3/checkpoint/backend.py new file mode 100644 index 00000000..271e486d --- /dev/null +++ b/src/underworld3/checkpoint/backend.py @@ -0,0 +1,80 @@ +"""Storage protocol for snapshot tokens. + +The protocol is shaped from day one to support both an in-memory backend +(dict of numpy arrays, v1) and an on-disk full-state backend (HDF5, +v1.1). Per the design note, the in-memory backend is the cheapest +correctness test of the abstraction; the on-disk backend locks in the +byte-level serialisation contract. +""" + +from __future__ import annotations + +from typing import Any, Protocol, runtime_checkable + +import numpy as np + + +@runtime_checkable +class CheckpointBackend(Protocol): + """Backing-store interface for :class:`underworld3.checkpoint.Snapshot`. + + Implementations + --------------- + - :class:`InMemoryBackend` — v1; numpy arrays held in process memory. + - ``OnDiskFullStateBackend`` — v1.1; single monolithic HDF5 file. + + Vectors are bulk numerical data; metadata is small scalars / dicts / + lists describing structure and provenance. + """ + + def save_vector(self, key: str, array: np.ndarray) -> None: ... + + def load_vector(self, key: str) -> np.ndarray: ... + + def save_metadata(self, key: str, value: Any) -> None: ... + + def load_metadata(self, key: str) -> Any: ... + + def list_vectors(self) -> list[str]: ... + + def list_metadata(self) -> list[str]: ... + + +class InMemoryBackend: + """Snapshot storage in process memory. + + Eager-copy on both ``save_vector`` and ``load_vector`` per the v1 + scope-boundary (no lazy / copy-on-write semantics). Per-snapshot + byte cost is the sum of captured vector sizes — expected to be + bounded by one Stokes solve's working memory for typical setups. + """ + + def __init__(self) -> None: + self._vectors: dict[str, np.ndarray] = {} + self._metadata: dict[str, Any] = {} + + def save_vector(self, key: str, array: np.ndarray) -> None: + if key in self._vectors: + raise KeyError(f"vector key already present in snapshot: {key!r}") + self._vectors[key] = np.asarray(array).copy() + + def load_vector(self, key: str) -> np.ndarray: + if key not in self._vectors: + raise KeyError(f"vector key not in snapshot: {key!r}") + return self._vectors[key].copy() + + def save_metadata(self, key: str, value: Any) -> None: + if key in self._metadata: + raise KeyError(f"metadata key already present in snapshot: {key!r}") + self._metadata[key] = value + + def load_metadata(self, key: str) -> Any: + if key not in self._metadata: + raise KeyError(f"metadata key not in snapshot: {key!r}") + return self._metadata[key] + + def list_vectors(self) -> list[str]: + return list(self._vectors.keys()) + + def list_metadata(self) -> list[str]: + return list(self._metadata.keys()) diff --git a/src/underworld3/checkpoint/snapshot.py b/src/underworld3/checkpoint/snapshot.py new file mode 100644 index 00000000..e4a20fb3 --- /dev/null +++ b/src/underworld3/checkpoint/snapshot.py @@ -0,0 +1,233 @@ +"""Unitary state capture and restore. + +A :class:`Snapshot` is a plain-Python token holding numpy data and small +metadata. It contains no PETSc Vec or DM handles, so it survives object +lifecycle changes within a process. Within a process, a snapshot can be +restored back onto the same :class:`underworld3.Model` instance; across +processes (v1.1, on-disk backend) the model is initialised from the +snapshot rather than restored to a previous state. + +This module implements the v1 scope: mesh coordinates and mesh-variable +DOFs. Swarm coverage, solver-internal Python state, on-disk backend, +schema versioning, and cross-process restore are scheduled for follow-up +PRs per the design note. +""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any, Optional + +import numpy as np + +from .backend import CheckpointBackend, InMemoryBackend + + +SNAPSHOT_SCHEMA_VERSION = 1 + + +class SnapshotInvalidatedError(RuntimeError): + """Raised when a snapshot can no longer be restored faithfully. + + Triggers in v1: mesh ``_mesh_version`` differs from the snapshot + (mesh has been adapted; DM identity has changed), or a registered + mesh / mesh-variable named in the snapshot is no longer present on + the target :class:`underworld3.Model`. + + Future triggers (subsequent PRs): swarm population-generation + counter mismatch, on-disk schema version that has no migration + path. + """ + + +@dataclass +class Snapshot: + """Unitary state token. + + Produced by :func:`snapshot`; consumed by :func:`restore`. Holds a + backend (where the bulk arrays live) plus per-Model bookkeeping — + which meshes were captured, which mesh variables were captured + under each mesh, and the mesh-version counters that gate + within-process restore. + + Attributes + ---------- + backend + Where the captured arrays live. v1 always uses + :class:`InMemoryBackend`; v1.1 will add on-disk backends. + schema_version + Snapshot file-format version. Restore refuses on mismatch in + v1; v1.1's migration registry will lift older versions to the + current schema for on-disk restore only. + mesh_keys + Stable ordering of captured mesh identifiers (``id(mesh)``); + determines restore order. + mesh_versions + Per-mesh ``_mesh_version`` at the moment of capture. Restore + compares against the current value; mismatch ⇒ + :class:`SnapshotInvalidatedError`. + meshvar_names + Mapping ``mesh_id → [var.clean_name, ...]`` — the mesh + variables captured for that mesh, in capture order. + metadata + User-visible bookkeeping (simulation time, step counter, free + text). Not load-bearing for restore correctness. + """ + + backend: CheckpointBackend + schema_version: int = SNAPSHOT_SCHEMA_VERSION + mesh_keys: list[int] = field(default_factory=list) + mesh_versions: dict[int, int] = field(default_factory=dict) + meshvar_names: dict[int, list[str]] = field(default_factory=dict) + metadata: dict[str, Any] = field(default_factory=dict) + + +def _mesh_coords_key(mesh_id: int) -> str: + return f"mesh:{mesh_id}:coords" + + +def _meshvar_key(mesh_id: int, var_clean_name: str) -> str: + return f"mesh:{mesh_id}:var:{var_clean_name}:gvec" + + +def snapshot(model, *, path: Optional[str] = None) -> Snapshot: + """Capture a unitary snapshot of the model's current state. + + Parameters + ---------- + model + The :class:`underworld3.Model` whose registered meshes and + mesh variables should be captured. + path + Reserved for the v1.1 on-disk backend. Passing a non-``None`` + value raises :class:`NotImplementedError` in v1. + + Returns + ------- + Snapshot + Token suitable for passing to :func:`restore` on the same + ``model`` instance within the same process. v1 captures mesh + coordinates and mesh-variable global-vector DOF values. + """ + if path is not None: + raise NotImplementedError( + "on-disk full-state snapshot is scheduled for v1.1; " + "v1 supports the in-memory backend only" + ) + + snap = Snapshot(backend=InMemoryBackend()) + for mesh_id, mesh in list(model._meshes.items()): + _capture_mesh(snap, mesh_id, mesh) + return snap + + +def _capture_mesh(snap: Snapshot, mesh_id: int, mesh) -> None: + if mesh_id in snap.mesh_keys: + return + snap.mesh_keys.append(mesh_id) + snap.mesh_versions[mesh_id] = int(getattr(mesh, "_mesh_version", 0)) + + coords = np.asarray(mesh.X.coords) + snap.backend.save_vector(_mesh_coords_key(mesh_id), coords) + + var_names: list[str] = [] + for var in mesh.vars.values(): + var._sync_lvec_to_gvec() + gvec_array = np.asarray(var._gvec.array) + snap.backend.save_vector(_meshvar_key(mesh_id, var.clean_name), gvec_array) + var_names.append(var.clean_name) + snap.meshvar_names[mesh_id] = var_names + + +def restore(model, snap: Snapshot) -> None: + """Restore the model from a snapshot. + + Restore order (within-process; cross-process is v1.1): + + 1. Mesh coordinates (via :meth:`Mesh._deform_mesh`, which rebuilds + coordinate caches and notifies registered callbacks). + 2. Mesh-variable DOFs (global vector written, then synced to local + vector via ``subdm.globalToLocal``). + 3. ``_mesh_version`` is verified equal to the capture value before + any write; mismatch raises :class:`SnapshotInvalidatedError`. + + Future PRs extend the order to: swarm positions + migrate → swarm + variable values → solver-internal Python state (DDt history, + parameter mutation history) → generation-counter validation last. + + Parameters + ---------- + model + The :class:`underworld3.Model` to restore. Must be the same + instance the snapshot came from (within-process restore). + snap + Token returned by :func:`snapshot`. + + Raises + ------ + SnapshotInvalidatedError + Mesh ``_mesh_version`` has changed since capture, or a + captured mesh / variable is no longer registered on the model. + TypeError + ``snap`` is not a :class:`Snapshot`. + """ + if not isinstance(snap, Snapshot): + raise TypeError( + f"expected underworld3.checkpoint.Snapshot, got {type(snap).__name__}" + ) + if snap.schema_version != SNAPSHOT_SCHEMA_VERSION: + raise SnapshotInvalidatedError( + f"snapshot schema version {snap.schema_version} does not match " + f"current {SNAPSHOT_SCHEMA_VERSION}; on-disk migration is v1.1" + ) + + for mesh_id in snap.mesh_keys: + mesh = model._meshes.get(mesh_id) + if mesh is None: + raise SnapshotInvalidatedError( + f"mesh id {mesh_id} from snapshot is not registered on this " + f"Model; within-process restore requires the originating Model" + ) + current_version = int(getattr(mesh, "_mesh_version", 0)) + captured_version = snap.mesh_versions[mesh_id] + if current_version != captured_version: + raise SnapshotInvalidatedError( + f"mesh._mesh_version moved from {captured_version} to " + f"{current_version} since snapshot — likely mesh.adapt() or " + f"deform_mesh() invalidated the DM identity" + ) + _restore_mesh(snap, mesh_id, mesh) + + +def _restore_mesh(snap: Snapshot, mesh_id: int, mesh) -> None: + coords = snap.backend.load_vector(_mesh_coords_key(mesh_id)) + expected_shape = np.asarray(mesh.X.coords).shape + if coords.shape != expected_shape: + raise SnapshotInvalidatedError( + f"mesh coordinate shape changed: snapshot {coords.shape} vs " + f"current {expected_shape}" + ) + mesh._deform_mesh(coords) + + current_vars = {var.clean_name: var for var in mesh.vars.values()} + for var_clean_name in snap.meshvar_names[mesh_id]: + var = current_vars.get(var_clean_name) + if var is None: + raise SnapshotInvalidatedError( + f"mesh variable {var_clean_name!r} from snapshot is not " + f"present on mesh; restore requires the same variable set" + ) + var._sync_lvec_to_gvec() # ensures _gvec exists with a current size + saved = snap.backend.load_vector(_meshvar_key(mesh_id, var_clean_name)) + current_shape = np.asarray(var._gvec.array).shape + if saved.shape != current_shape: + raise SnapshotInvalidatedError( + f"variable {var_clean_name!r} gvec shape changed: snapshot " + f"{saved.shape} vs current {current_shape}" + ) + var._gvec.array[...] = saved + iset, subdm = mesh.dm.createSubDM(var.field_id) + subdm.globalToLocal(var._gvec, var._lvec, addv=False) + iset.destroy() + subdm.destroy() + mesh._stale_lvec = True diff --git a/src/underworld3/model.py b/src/underworld3/model.py index 8e1801d2..ca884569 100644 --- a/src/underworld3/model.py +++ b/src/underworld3/model.py @@ -565,6 +565,36 @@ def get_solver(self, name: str): """Get a solver by name from the model registry""" return self._solvers.get(name) + def snapshot(self, *, path: Optional[str] = None): + """Capture a unitary in-memory snapshot of the model's state. + + v1 covers mesh coordinates and mesh-variable DOFs across every + registered mesh. Subsequent PRs extend coverage to swarms and + solver-internal Python state. + + Pass ``path=...`` to write to an HDF5 file once the on-disk + backend lands (v1.1); v1 raises ``NotImplementedError``. + + See ``docs/developer/design/in_memory_checkpoint_design.md`` + for the full design. + """ + from underworld3.checkpoint import snapshot as _snapshot + + return _snapshot(self, path=path) + + def restore(self, snap) -> None: + """Restore the model from a :class:`underworld3.checkpoint.Snapshot`. + + Within-process restore: ``snap`` must have been produced by + :meth:`snapshot` on this same ``Model`` instance. Raises + :class:`underworld3.checkpoint.SnapshotInvalidatedError` if + the mesh has been adapted, or a captured mesh / variable is + no longer registered. + """ + from underworld3.checkpoint import restore as _restore + + return _restore(self, snap) + def define_parameter(self, name: str, ptype=None, **kwargs): """ Define a new parameter with validation rules. diff --git a/tests/test_0007_snapshot_inmemory.py b/tests/test_0007_snapshot_inmemory.py new file mode 100644 index 00000000..71dfbb43 --- /dev/null +++ b/tests/test_0007_snapshot_inmemory.py @@ -0,0 +1,115 @@ +import pytest + +pytestmark = [pytest.mark.level_1, pytest.mark.tier_a] + +import numpy as np + + +def _fresh_model_and_mesh(): + import underworld3 as uw + + uw.reset_default_model() + model = uw.get_default_model() + mesh = uw.meshing.UnstructuredSimplexBox( + minCoords=(0.0, 0.0), maxCoords=(1.0, 1.0), cellSize=1.0 / 8.0 + ) + return uw, model, mesh + + +def test_meshvariable_in_memory_roundtrip(): + """Snapshot, scribble, restore: the MV global vector is recovered exactly.""" + uw, model, mesh = _fresh_model_and_mesh() + T = uw.discretisation.MeshVariable("T", mesh, 1, degree=2) + + T.array[:, 0, 0] = T.coords[:, 0] + 2.0 * T.coords[:, 1] + pre_array = np.asarray(T.array[...]).copy() + + snap = model.snapshot() + + T.array[...] = -42.0 + assert not np.allclose(np.asarray(T.array[...]), pre_array), "scribble didn't take" + + model.restore(snap) + + assert np.allclose(np.asarray(T.array[...]), pre_array, atol=0.0, rtol=0.0), ( + "MeshVariable.array is not bit-equivalent after restore" + ) + + +def test_multiple_meshvariables_roundtrip(): + """All MVs on a mesh are captured and restored, not just the first.""" + uw, model, mesh = _fresh_model_and_mesh() + T = uw.discretisation.MeshVariable("T", mesh, 1, degree=2) + V = uw.discretisation.MeshVariable("V", mesh, 2, degree=2) + + T.array[:, 0, 0] = 1.5 * T.coords[:, 0] + V.array[:, 0, 0] = 3.0 * V.coords[:, 0] + V.array[:, 0, 1] = 7.0 * V.coords[:, 1] + + T_pre = np.asarray(T.array[...]).copy() + V_pre = np.asarray(V.array[...]).copy() + + snap = model.snapshot() + + T.array[...] = 0.0 + V.array[...] = 0.0 + + model.restore(snap) + + assert np.allclose(np.asarray(T.array[...]), T_pre) + assert np.allclose(np.asarray(V.array[...]), V_pre) + + +def test_snapshot_is_independent_of_subsequent_writes(): + """Captured array is a copy; writes to the live MV don't leak into the snapshot.""" + uw, model, mesh = _fresh_model_and_mesh() + T = uw.discretisation.MeshVariable("T", mesh, 1, degree=2) + T.array[:, 0, 0] = 5.0 + + snap = model.snapshot() + T.array[...] = -1.0 + + # The backend still holds the captured value, not the post-write value. + keys = snap.backend.list_vectors() + var_key = next(k for k in keys if "var:T" in k) + captured = snap.backend.load_vector(var_key) + assert np.allclose(captured, 5.0), ( + "in-memory backend did not isolate the captured array from later writes" + ) + + +def test_mesh_version_invalidates_restore(): + """A bumped _mesh_version makes restore refuse rather than silently corrupt.""" + import underworld3 as uw + from underworld3.checkpoint import SnapshotInvalidatedError + + uw_, model, mesh = _fresh_model_and_mesh() + T = uw.discretisation.MeshVariable("T", mesh, 1, degree=2) + T.array[:, 0, 0] = 1.0 + + snap = model.snapshot() + + # Simulate a mesh-mutation event (e.g. adapt(), or any deformation + # routed through the high-level callback that bumps _mesh_version). + mesh._mesh_version += 1 + + with pytest.raises(SnapshotInvalidatedError, match="_mesh_version"): + model.restore(snap) + + +def test_restore_rejects_non_snapshot(): + """A bare dict / array is not a Snapshot; restore raises TypeError.""" + uw, model, mesh = _fresh_model_and_mesh() + _ = uw.discretisation.MeshVariable("T", mesh, 1, degree=2) + + with pytest.raises(TypeError): + model.restore({"not": "a snapshot"}) + + +def test_snapshot_path_is_v1_1_scope(): + """Passing path= raises NotImplementedError until the on-disk backend lands.""" + uw, model, mesh = _fresh_model_and_mesh() + _ = uw.discretisation.MeshVariable("T", mesh, 1, degree=2) + + with pytest.raises(NotImplementedError): + model.snapshot(path="/tmp/should_not_be_written.h5") From 001f961679a7f0ab446f8b5b5993549c8aa36b69 Mon Sep 17 00:00:00 2001 From: lmoresi Date: Mon, 11 May 2026 21:28:30 +1000 Subject: [PATCH 03/15] checkpoint: swarm coverage + _population_generation counter (PR 2) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Extends the unitary snapshot to capture per-rank swarm positions and user swarm-variable data. Adds Swarm._population_generation, a counter bumped at every particle-population mutation site, used as the within-process invalidation gate: restoring a snapshot taken before a populate / migrate / add_particles / remesh event raises SnapshotInvalidatedError rather than silently corrupting a now-stale position array. Counter init + bump sites (src/underworld3/swarm.py): - Swarm.__init__: initialise to 0 next to _mesh_version. - populate(): bump once at the end (covers the 1-3 internal addNPoints calls in the populate body). - Swarm.migrate() after the migration_disabled early-exit: bump unconditionally; conservative even when migrate is a no-op, because under-bumping risks silent corruption while over-bumping is safe. - add_particles_with_coordinates() after its direct self.dm.migrate(): this path doesn't go through Swarm.migrate so we bump explicitly. - add_particles_with_global_coordinates() right after addNPoints: catches the migrate=False case too; the migrate=True path will double-bump via Swarm.migrate, which is fine. - advection() remesh path after the addNPoints reinjection. Snapshot extensions (src/underworld3/checkpoint/snapshot.py): - New fields: swarm_keys, swarm_generations, swarm_mesh_versions, swarmvar_names. - _capture_swarm: reads DMSwarmPIC_coor via dm.getField → copy → restoreField; iterates swarm.vars excluding DMSwarm* internals; records both _population_generation and _mesh_version. - _restore_swarm: validates both counters before any write; writes back positions + per-var data in place. Deliberately bypasses populate/add_particles/migrate so the restore itself does not bump the counter or mutate the population we just confirmed stable. Test coverage (tests/test_0007_snapshot_inmemory.py): 5 new tests on top of the 6 mesh-only tests: - swarm positions + user-variable roundtrip after scribble - counter bumps on populate, migrate, add_particles_with_coordinates, add_particles_with_global_coordinates (in monotonic order) - migrate-between-snapshot-and-restore raises SnapshotInvalidatedError - add_particles-between-snapshot-and-restore raises likewise - DMSwarm_* internal variables stay out of the captured key set Not yet covered: cross-process restore (v1.1), advection remesh invalidation test (needs a recycle-enabled swarm + a velocity field, larger setup than belongs in the core 0007 file). Underworld development team with AI support from Claude Code (https://claude.com/claude-code) --- src/underworld3/checkpoint/snapshot.py | 127 +++++++++++++++++++++++-- src/underworld3/swarm.py | 26 +++++ tests/test_0007_snapshot_inmemory.py | 101 ++++++++++++++++++++ 3 files changed, 247 insertions(+), 7 deletions(-) diff --git a/src/underworld3/checkpoint/snapshot.py b/src/underworld3/checkpoint/snapshot.py index e4a20fb3..18a61584 100644 --- a/src/underworld3/checkpoint/snapshot.py +++ b/src/underworld3/checkpoint/snapshot.py @@ -30,13 +30,14 @@ class SnapshotInvalidatedError(RuntimeError): """Raised when a snapshot can no longer be restored faithfully. Triggers in v1: mesh ``_mesh_version`` differs from the snapshot - (mesh has been adapted; DM identity has changed), or a registered - mesh / mesh-variable named in the snapshot is no longer present on - the target :class:`underworld3.Model`. - - Future triggers (subsequent PRs): swarm population-generation - counter mismatch, on-disk schema version that has no migration - path. + (mesh has been adapted; DM identity has changed); swarm + ``_population_generation`` differs (a populate / migrate / + add_particles / remesh event ran since capture); or a captured + mesh / variable / swarm is no longer registered on the target + :class:`underworld3.Model`. + + Future triggers (subsequent PRs): on-disk schema version that has + no migration path. """ @@ -79,6 +80,10 @@ class Snapshot: mesh_keys: list[int] = field(default_factory=list) mesh_versions: dict[int, int] = field(default_factory=dict) meshvar_names: dict[int, list[str]] = field(default_factory=dict) + swarm_keys: list[int] = field(default_factory=list) + swarm_generations: dict[int, int] = field(default_factory=dict) + swarm_mesh_versions: dict[int, int] = field(default_factory=dict) + swarmvar_names: dict[int, list[str]] = field(default_factory=dict) metadata: dict[str, Any] = field(default_factory=dict) @@ -90,6 +95,25 @@ def _meshvar_key(mesh_id: int, var_clean_name: str) -> str: return f"mesh:{mesh_id}:var:{var_clean_name}:gvec" +def _swarm_coords_key(swarm_id: int) -> str: + return f"swarm:{swarm_id}:coords" + + +def _swarmvar_key(swarm_id: int, var_clean_name: str) -> str: + return f"swarm:{swarm_id}:var:{var_clean_name}:data" + + +def _is_internal_swarmvar(var_name: str) -> bool: + """Filter PETSc-managed internal swarm variables from user capture. + + ``DMSwarmPIC_coor`` is captured separately via the particle-coords + path. ``DMSwarm_X0`` and ``DMSwarm_remeshed`` carry recycle-related + bookkeeping that is regenerated on next solve and is out of scope + for v1 capture. + """ + return var_name.startswith("DMSwarm") + + def snapshot(model, *, path: Optional[str] = None) -> Snapshot: """Capture a unitary snapshot of the model's current state. @@ -118,6 +142,8 @@ def snapshot(model, *, path: Optional[str] = None) -> Snapshot: snap = Snapshot(backend=InMemoryBackend()) for mesh_id, mesh in list(model._meshes.items()): _capture_mesh(snap, mesh_id, mesh) + for swarm_id, swarm in list(model._swarms.items()): + _capture_swarm(snap, swarm_id, swarm) return snap @@ -139,6 +165,27 @@ def _capture_mesh(snap: Snapshot, mesh_id: int, mesh) -> None: snap.meshvar_names[mesh_id] = var_names +def _capture_swarm(snap: Snapshot, swarm_id: int, swarm) -> None: + if swarm_id in snap.swarm_keys: + return + snap.swarm_keys.append(swarm_id) + snap.swarm_generations[swarm_id] = int(swarm._population_generation) + snap.swarm_mesh_versions[swarm_id] = int(getattr(swarm, "_mesh_version", 0)) + + coords = swarm.dm.getField("DMSwarmPIC_coor").reshape((-1, swarm.dim)).copy() + swarm.dm.restoreField("DMSwarmPIC_coor") + snap.backend.save_vector(_swarm_coords_key(swarm_id), coords) + + var_names: list[str] = [] + for var in list(swarm.vars.values()): + if _is_internal_swarmvar(var.name): + continue + data = np.asarray(var.data).copy() + snap.backend.save_vector(_swarmvar_key(swarm_id, var.clean_name), data) + var_names.append(var.clean_name) + snap.swarmvar_names[swarm_id] = var_names + + def restore(model, snap: Snapshot) -> None: """Restore the model from a snapshot. @@ -198,6 +245,32 @@ def restore(model, snap: Snapshot) -> None: ) _restore_mesh(snap, mesh_id, mesh) + for swarm_id in snap.swarm_keys: + swarm = model._swarms.get(swarm_id) + if swarm is None: + raise SnapshotInvalidatedError( + f"swarm id {swarm_id} from snapshot is not registered on " + f"this Model; within-process restore requires the originating " + f"Model" + ) + current_gen = int(swarm._population_generation) + captured_gen = snap.swarm_generations[swarm_id] + if current_gen != captured_gen: + raise SnapshotInvalidatedError( + f"swarm _population_generation moved from {captured_gen} " + f"to {current_gen} since snapshot — populate/migrate/" + f"add_particles/remesh ran between snapshot and restore" + ) + current_mv = int(getattr(swarm, "_mesh_version", 0)) + captured_mv = snap.swarm_mesh_versions[swarm_id] + if current_mv != captured_mv: + raise SnapshotInvalidatedError( + f"swarm._mesh_version moved from {captured_mv} to {current_mv} " + f"since snapshot — the parent mesh changed and the swarm " + f"would need to re-migrate to be consistent" + ) + _restore_swarm(snap, swarm_id, swarm) + def _restore_mesh(snap: Snapshot, mesh_id: int, mesh) -> None: coords = snap.backend.load_vector(_mesh_coords_key(mesh_id)) @@ -231,3 +304,43 @@ def _restore_mesh(snap: Snapshot, mesh_id: int, mesh) -> None: iset.destroy() subdm.destroy() mesh._stale_lvec = True + + +def _restore_swarm(snap: Snapshot, swarm_id: int, swarm) -> None: + """Write captured particle positions and user-var values back to the swarm. + + The population-generation counter has already been verified equal + by the caller, so per-rank array sizes match the captured arrays + and we can write in place. We deliberately bypass ``populate`` / + ``add_particles_*`` / ``migrate`` because invoking them would bump + the counter and (more importantly) mutate the population we just + confirmed to be stable. + """ + saved_coords = snap.backend.load_vector(_swarm_coords_key(swarm_id)) + coord_field = swarm.dm.getField("DMSwarmPIC_coor").reshape((-1, swarm.dim)) + if coord_field.shape != saved_coords.shape: + swarm.dm.restoreField("DMSwarmPIC_coor") + raise SnapshotInvalidatedError( + f"swarm particle-coord shape changed: snapshot {saved_coords.shape} " + f"vs current {coord_field.shape} — population identity differs even " + f"though _population_generation matched; this is a programming error" + ) + coord_field[...] = saved_coords + swarm.dm.restoreField("DMSwarmPIC_coor") + + current_vars = {var.clean_name: var for var in swarm.vars.values()} + for var_clean_name in snap.swarmvar_names[swarm_id]: + var = current_vars.get(var_clean_name) + if var is None: + raise SnapshotInvalidatedError( + f"swarm variable {var_clean_name!r} from snapshot is not " + f"present on this swarm; restore requires the same variable set" + ) + saved = snap.backend.load_vector(_swarmvar_key(swarm_id, var_clean_name)) + current = np.asarray(var.data) + if saved.shape != current.shape: + raise SnapshotInvalidatedError( + f"swarm variable {var_clean_name!r} data shape changed: " + f"snapshot {saved.shape} vs current {current.shape}" + ) + current[...] = saved diff --git a/src/underworld3/swarm.py b/src/underworld3/swarm.py index ff8dc593..50e0f945 100644 --- a/src/underworld3/swarm.py +++ b/src/underworld3/swarm.py @@ -2493,6 +2493,12 @@ def __init__(self, mesh, recycle_rate=0, verbose=False, clip_to_mesh=True): # Mesh version tracking for coordinate change detection self._mesh_version = mesh._mesh_version + # Snapshot/restore invalidation counter: bumped on every + # particle-population mutation (populate, add_particles_*, + # migrate, advection remesh). See + # docs/developer/design/in_memory_checkpoint_design.md. + self._population_generation = 0 + # Register this swarm with the mesh for coordinate change notifications mesh.register_swarm(self) @@ -3286,6 +3292,9 @@ def populate( offset = swarm_orig_size * i self._remeshed.data[offset::, 0] = i + # Snapshot invalidation: particle population just changed. + self._population_generation += 1 + return @timing.routine_timer_decorator @@ -3315,6 +3324,12 @@ def migrate( if self._migration_disabled: return + # Snapshot invalidation: migration may move or drop particles, + # changing per-rank population identity. Conservative bump even + # if the call is ultimately a no-op — over-bumping is safe, + # under-bumping risks silent corruption on restore. + self._population_generation += 1 + from time import time if delete_lost_points is None: @@ -3536,6 +3551,10 @@ def add_particles_with_coordinates(self, coordinatesArray) -> int: if hasattr(var, "_canonical_data"): var._canonical_data = None + # Snapshot invalidation: addNPoints + dm.migrate is a direct + # PETSc call path that does not go through Swarm.migrate. + self._population_generation += 1 + return npoints @timing.routine_timer_decorator @@ -3604,6 +3623,10 @@ def add_particles_with_global_coordinates( self.dm.finalizeFieldRegister() self.dm.addNPoints(npoints=npoints) + # Snapshot invalidation: population changed even if the caller + # opts out of the post-add migration (migrate=False). + self._population_generation += 1 + # Add new points with provided coords # Record the current rank (migration needs to know where we start from !) @@ -4470,6 +4493,9 @@ def advection( self.dm.addNPoints(num_remeshed_points) + # Snapshot invalidation: remesh just re-injected particles. + self._population_generation += 1 + ## cellid = self.dm.getField("DMSwarm_cellid") coords = self.dm.getField("DMSwarmPIC_coor").reshape((-1, self.dim)) rmsh = self.dm.getField("DMSwarm_remeshed") diff --git a/tests/test_0007_snapshot_inmemory.py b/tests/test_0007_snapshot_inmemory.py index 71dfbb43..edcfb3b0 100644 --- a/tests/test_0007_snapshot_inmemory.py +++ b/tests/test_0007_snapshot_inmemory.py @@ -113,3 +113,104 @@ def test_snapshot_path_is_v1_1_scope(): with pytest.raises(NotImplementedError): model.snapshot(path="/tmp/should_not_be_written.h5") + + +# ----- Swarm coverage ----- + + +def _fresh_model_mesh_and_swarm(with_material=True): + """Create a fresh model + mesh + swarm. Swarm-variable creation must + happen before populate(), so we build everything in one place. + """ + import underworld3 as uw + + uw.reset_default_model() + model = uw.get_default_model() + mesh = uw.meshing.UnstructuredSimplexBox( + minCoords=(0.0, 0.0), maxCoords=(1.0, 1.0), cellSize=1.0 / 4.0 + ) + swarm = uw.swarm.Swarm(mesh) + material = None + if with_material: + material = swarm.add_variable("material", 1, dtype=float) + swarm.populate(fill_param=2) + return uw, model, mesh, swarm, material + + +def test_swarm_positions_and_variable_roundtrip(): + """Snapshot, scramble swarm positions + svar, restore: both come back.""" + uw, model, mesh, swarm, material = _fresh_model_mesh_and_swarm() + + coords = swarm._particle_coordinates.data + material.data[:, 0] = 0.5 * coords[:, 0] + coords[:, 1] + coords_pre = coords.copy() + material_pre = np.asarray(material.data).copy() + + snap = model.snapshot() + + coord_field = swarm.dm.getField("DMSwarmPIC_coor").reshape((-1, swarm.dim)) + coord_field[...] = -99.0 + swarm.dm.restoreField("DMSwarmPIC_coor") + material.data[...] = -99.0 + + model.restore(snap) + + assert np.allclose(swarm._particle_coordinates.data, coords_pre) + assert np.allclose(np.asarray(material.data), material_pre) + + +def test_swarm_population_generation_starts_at_zero_and_bumps(): + """Sanity-check the counter bumps on each mutation category.""" + uw, model, mesh, swarm, _ = _fresh_model_mesh_and_swarm(with_material=False) + after_populate = swarm._population_generation + swarm.migrate(remove_sent_points=True) + after_migrate = swarm._population_generation + swarm.add_particles_with_coordinates(np.array([[0.5, 0.5]])) + after_add_local = swarm._population_generation + swarm.add_particles_with_global_coordinates(np.array([[0.25, 0.25]])) + after_add_global = swarm._population_generation + + assert after_populate >= 1 + assert after_migrate > after_populate + assert after_add_local > after_migrate + assert after_add_global > after_add_local + + +def test_swarm_migrate_invalidates_restore(): + """A migrate() call between snapshot and restore makes restore refuse.""" + from underworld3.checkpoint import SnapshotInvalidatedError + + uw, model, mesh, swarm, _ = _fresh_model_mesh_and_swarm() + + snap = model.snapshot() + swarm.migrate(remove_sent_points=True) + + with pytest.raises(SnapshotInvalidatedError, match="_population_generation"): + model.restore(snap) + + +def test_swarm_add_particles_invalidates_restore(): + """add_particles_with_coordinates between snapshot and restore raises.""" + from underworld3.checkpoint import SnapshotInvalidatedError + + uw, model, mesh, swarm, _ = _fresh_model_mesh_and_swarm() + + snap = model.snapshot() + swarm.add_particles_with_coordinates(np.array([[0.5, 0.5]])) + + with pytest.raises(SnapshotInvalidatedError, match="_population_generation"): + model.restore(snap) + + +def test_swarm_internal_variables_are_not_captured(): + """Internal DMSwarm_* variables stay out of the snapshot key list.""" + uw, model, mesh, swarm, _ = _fresh_model_mesh_and_swarm() + + snap = model.snapshot() + keys = snap.backend.list_vectors() + swarmvar_keys = [k for k in keys if k.startswith(f"swarm:{id(swarm)}:var:")] + captured_names = {k.split(":var:")[1].split(":data")[0] for k in swarmvar_keys} + + # User variable present, PETSc-internal ones absent. + assert "material" in captured_names + assert not any(n.startswith("DMSwarm") for n in captured_names) From 6d1b63548b9c0fb2298d56869e4d51597bdd2bc7 Mon Sep 17 00:00:00 2001 From: lmoresi Date: Mon, 11 May 2026 22:35:21 +1000 Subject: [PATCH 04/15] Revert "checkpoint: swarm coverage + _population_generation counter (PR 2)" This reverts commit 001f961679a7f0ab446f8b5b5993549c8aa36b69. --- src/underworld3/checkpoint/snapshot.py | 127 ++----------------------- src/underworld3/swarm.py | 26 ----- tests/test_0007_snapshot_inmemory.py | 101 -------------------- 3 files changed, 7 insertions(+), 247 deletions(-) diff --git a/src/underworld3/checkpoint/snapshot.py b/src/underworld3/checkpoint/snapshot.py index 18a61584..e4a20fb3 100644 --- a/src/underworld3/checkpoint/snapshot.py +++ b/src/underworld3/checkpoint/snapshot.py @@ -30,14 +30,13 @@ class SnapshotInvalidatedError(RuntimeError): """Raised when a snapshot can no longer be restored faithfully. Triggers in v1: mesh ``_mesh_version`` differs from the snapshot - (mesh has been adapted; DM identity has changed); swarm - ``_population_generation`` differs (a populate / migrate / - add_particles / remesh event ran since capture); or a captured - mesh / variable / swarm is no longer registered on the target - :class:`underworld3.Model`. - - Future triggers (subsequent PRs): on-disk schema version that has - no migration path. + (mesh has been adapted; DM identity has changed), or a registered + mesh / mesh-variable named in the snapshot is no longer present on + the target :class:`underworld3.Model`. + + Future triggers (subsequent PRs): swarm population-generation + counter mismatch, on-disk schema version that has no migration + path. """ @@ -80,10 +79,6 @@ class Snapshot: mesh_keys: list[int] = field(default_factory=list) mesh_versions: dict[int, int] = field(default_factory=dict) meshvar_names: dict[int, list[str]] = field(default_factory=dict) - swarm_keys: list[int] = field(default_factory=list) - swarm_generations: dict[int, int] = field(default_factory=dict) - swarm_mesh_versions: dict[int, int] = field(default_factory=dict) - swarmvar_names: dict[int, list[str]] = field(default_factory=dict) metadata: dict[str, Any] = field(default_factory=dict) @@ -95,25 +90,6 @@ def _meshvar_key(mesh_id: int, var_clean_name: str) -> str: return f"mesh:{mesh_id}:var:{var_clean_name}:gvec" -def _swarm_coords_key(swarm_id: int) -> str: - return f"swarm:{swarm_id}:coords" - - -def _swarmvar_key(swarm_id: int, var_clean_name: str) -> str: - return f"swarm:{swarm_id}:var:{var_clean_name}:data" - - -def _is_internal_swarmvar(var_name: str) -> bool: - """Filter PETSc-managed internal swarm variables from user capture. - - ``DMSwarmPIC_coor`` is captured separately via the particle-coords - path. ``DMSwarm_X0`` and ``DMSwarm_remeshed`` carry recycle-related - bookkeeping that is regenerated on next solve and is out of scope - for v1 capture. - """ - return var_name.startswith("DMSwarm") - - def snapshot(model, *, path: Optional[str] = None) -> Snapshot: """Capture a unitary snapshot of the model's current state. @@ -142,8 +118,6 @@ def snapshot(model, *, path: Optional[str] = None) -> Snapshot: snap = Snapshot(backend=InMemoryBackend()) for mesh_id, mesh in list(model._meshes.items()): _capture_mesh(snap, mesh_id, mesh) - for swarm_id, swarm in list(model._swarms.items()): - _capture_swarm(snap, swarm_id, swarm) return snap @@ -165,27 +139,6 @@ def _capture_mesh(snap: Snapshot, mesh_id: int, mesh) -> None: snap.meshvar_names[mesh_id] = var_names -def _capture_swarm(snap: Snapshot, swarm_id: int, swarm) -> None: - if swarm_id in snap.swarm_keys: - return - snap.swarm_keys.append(swarm_id) - snap.swarm_generations[swarm_id] = int(swarm._population_generation) - snap.swarm_mesh_versions[swarm_id] = int(getattr(swarm, "_mesh_version", 0)) - - coords = swarm.dm.getField("DMSwarmPIC_coor").reshape((-1, swarm.dim)).copy() - swarm.dm.restoreField("DMSwarmPIC_coor") - snap.backend.save_vector(_swarm_coords_key(swarm_id), coords) - - var_names: list[str] = [] - for var in list(swarm.vars.values()): - if _is_internal_swarmvar(var.name): - continue - data = np.asarray(var.data).copy() - snap.backend.save_vector(_swarmvar_key(swarm_id, var.clean_name), data) - var_names.append(var.clean_name) - snap.swarmvar_names[swarm_id] = var_names - - def restore(model, snap: Snapshot) -> None: """Restore the model from a snapshot. @@ -245,32 +198,6 @@ def restore(model, snap: Snapshot) -> None: ) _restore_mesh(snap, mesh_id, mesh) - for swarm_id in snap.swarm_keys: - swarm = model._swarms.get(swarm_id) - if swarm is None: - raise SnapshotInvalidatedError( - f"swarm id {swarm_id} from snapshot is not registered on " - f"this Model; within-process restore requires the originating " - f"Model" - ) - current_gen = int(swarm._population_generation) - captured_gen = snap.swarm_generations[swarm_id] - if current_gen != captured_gen: - raise SnapshotInvalidatedError( - f"swarm _population_generation moved from {captured_gen} " - f"to {current_gen} since snapshot — populate/migrate/" - f"add_particles/remesh ran between snapshot and restore" - ) - current_mv = int(getattr(swarm, "_mesh_version", 0)) - captured_mv = snap.swarm_mesh_versions[swarm_id] - if current_mv != captured_mv: - raise SnapshotInvalidatedError( - f"swarm._mesh_version moved from {captured_mv} to {current_mv} " - f"since snapshot — the parent mesh changed and the swarm " - f"would need to re-migrate to be consistent" - ) - _restore_swarm(snap, swarm_id, swarm) - def _restore_mesh(snap: Snapshot, mesh_id: int, mesh) -> None: coords = snap.backend.load_vector(_mesh_coords_key(mesh_id)) @@ -304,43 +231,3 @@ def _restore_mesh(snap: Snapshot, mesh_id: int, mesh) -> None: iset.destroy() subdm.destroy() mesh._stale_lvec = True - - -def _restore_swarm(snap: Snapshot, swarm_id: int, swarm) -> None: - """Write captured particle positions and user-var values back to the swarm. - - The population-generation counter has already been verified equal - by the caller, so per-rank array sizes match the captured arrays - and we can write in place. We deliberately bypass ``populate`` / - ``add_particles_*`` / ``migrate`` because invoking them would bump - the counter and (more importantly) mutate the population we just - confirmed to be stable. - """ - saved_coords = snap.backend.load_vector(_swarm_coords_key(swarm_id)) - coord_field = swarm.dm.getField("DMSwarmPIC_coor").reshape((-1, swarm.dim)) - if coord_field.shape != saved_coords.shape: - swarm.dm.restoreField("DMSwarmPIC_coor") - raise SnapshotInvalidatedError( - f"swarm particle-coord shape changed: snapshot {saved_coords.shape} " - f"vs current {coord_field.shape} — population identity differs even " - f"though _population_generation matched; this is a programming error" - ) - coord_field[...] = saved_coords - swarm.dm.restoreField("DMSwarmPIC_coor") - - current_vars = {var.clean_name: var for var in swarm.vars.values()} - for var_clean_name in snap.swarmvar_names[swarm_id]: - var = current_vars.get(var_clean_name) - if var is None: - raise SnapshotInvalidatedError( - f"swarm variable {var_clean_name!r} from snapshot is not " - f"present on this swarm; restore requires the same variable set" - ) - saved = snap.backend.load_vector(_swarmvar_key(swarm_id, var_clean_name)) - current = np.asarray(var.data) - if saved.shape != current.shape: - raise SnapshotInvalidatedError( - f"swarm variable {var_clean_name!r} data shape changed: " - f"snapshot {saved.shape} vs current {current.shape}" - ) - current[...] = saved diff --git a/src/underworld3/swarm.py b/src/underworld3/swarm.py index 50e0f945..ff8dc593 100644 --- a/src/underworld3/swarm.py +++ b/src/underworld3/swarm.py @@ -2493,12 +2493,6 @@ def __init__(self, mesh, recycle_rate=0, verbose=False, clip_to_mesh=True): # Mesh version tracking for coordinate change detection self._mesh_version = mesh._mesh_version - # Snapshot/restore invalidation counter: bumped on every - # particle-population mutation (populate, add_particles_*, - # migrate, advection remesh). See - # docs/developer/design/in_memory_checkpoint_design.md. - self._population_generation = 0 - # Register this swarm with the mesh for coordinate change notifications mesh.register_swarm(self) @@ -3292,9 +3286,6 @@ def populate( offset = swarm_orig_size * i self._remeshed.data[offset::, 0] = i - # Snapshot invalidation: particle population just changed. - self._population_generation += 1 - return @timing.routine_timer_decorator @@ -3324,12 +3315,6 @@ def migrate( if self._migration_disabled: return - # Snapshot invalidation: migration may move or drop particles, - # changing per-rank population identity. Conservative bump even - # if the call is ultimately a no-op — over-bumping is safe, - # under-bumping risks silent corruption on restore. - self._population_generation += 1 - from time import time if delete_lost_points is None: @@ -3551,10 +3536,6 @@ def add_particles_with_coordinates(self, coordinatesArray) -> int: if hasattr(var, "_canonical_data"): var._canonical_data = None - # Snapshot invalidation: addNPoints + dm.migrate is a direct - # PETSc call path that does not go through Swarm.migrate. - self._population_generation += 1 - return npoints @timing.routine_timer_decorator @@ -3623,10 +3604,6 @@ def add_particles_with_global_coordinates( self.dm.finalizeFieldRegister() self.dm.addNPoints(npoints=npoints) - # Snapshot invalidation: population changed even if the caller - # opts out of the post-add migration (migrate=False). - self._population_generation += 1 - # Add new points with provided coords # Record the current rank (migration needs to know where we start from !) @@ -4493,9 +4470,6 @@ def advection( self.dm.addNPoints(num_remeshed_points) - # Snapshot invalidation: remesh just re-injected particles. - self._population_generation += 1 - ## cellid = self.dm.getField("DMSwarm_cellid") coords = self.dm.getField("DMSwarmPIC_coor").reshape((-1, self.dim)) rmsh = self.dm.getField("DMSwarm_remeshed") diff --git a/tests/test_0007_snapshot_inmemory.py b/tests/test_0007_snapshot_inmemory.py index edcfb3b0..71dfbb43 100644 --- a/tests/test_0007_snapshot_inmemory.py +++ b/tests/test_0007_snapshot_inmemory.py @@ -113,104 +113,3 @@ def test_snapshot_path_is_v1_1_scope(): with pytest.raises(NotImplementedError): model.snapshot(path="/tmp/should_not_be_written.h5") - - -# ----- Swarm coverage ----- - - -def _fresh_model_mesh_and_swarm(with_material=True): - """Create a fresh model + mesh + swarm. Swarm-variable creation must - happen before populate(), so we build everything in one place. - """ - import underworld3 as uw - - uw.reset_default_model() - model = uw.get_default_model() - mesh = uw.meshing.UnstructuredSimplexBox( - minCoords=(0.0, 0.0), maxCoords=(1.0, 1.0), cellSize=1.0 / 4.0 - ) - swarm = uw.swarm.Swarm(mesh) - material = None - if with_material: - material = swarm.add_variable("material", 1, dtype=float) - swarm.populate(fill_param=2) - return uw, model, mesh, swarm, material - - -def test_swarm_positions_and_variable_roundtrip(): - """Snapshot, scramble swarm positions + svar, restore: both come back.""" - uw, model, mesh, swarm, material = _fresh_model_mesh_and_swarm() - - coords = swarm._particle_coordinates.data - material.data[:, 0] = 0.5 * coords[:, 0] + coords[:, 1] - coords_pre = coords.copy() - material_pre = np.asarray(material.data).copy() - - snap = model.snapshot() - - coord_field = swarm.dm.getField("DMSwarmPIC_coor").reshape((-1, swarm.dim)) - coord_field[...] = -99.0 - swarm.dm.restoreField("DMSwarmPIC_coor") - material.data[...] = -99.0 - - model.restore(snap) - - assert np.allclose(swarm._particle_coordinates.data, coords_pre) - assert np.allclose(np.asarray(material.data), material_pre) - - -def test_swarm_population_generation_starts_at_zero_and_bumps(): - """Sanity-check the counter bumps on each mutation category.""" - uw, model, mesh, swarm, _ = _fresh_model_mesh_and_swarm(with_material=False) - after_populate = swarm._population_generation - swarm.migrate(remove_sent_points=True) - after_migrate = swarm._population_generation - swarm.add_particles_with_coordinates(np.array([[0.5, 0.5]])) - after_add_local = swarm._population_generation - swarm.add_particles_with_global_coordinates(np.array([[0.25, 0.25]])) - after_add_global = swarm._population_generation - - assert after_populate >= 1 - assert after_migrate > after_populate - assert after_add_local > after_migrate - assert after_add_global > after_add_local - - -def test_swarm_migrate_invalidates_restore(): - """A migrate() call between snapshot and restore makes restore refuse.""" - from underworld3.checkpoint import SnapshotInvalidatedError - - uw, model, mesh, swarm, _ = _fresh_model_mesh_and_swarm() - - snap = model.snapshot() - swarm.migrate(remove_sent_points=True) - - with pytest.raises(SnapshotInvalidatedError, match="_population_generation"): - model.restore(snap) - - -def test_swarm_add_particles_invalidates_restore(): - """add_particles_with_coordinates between snapshot and restore raises.""" - from underworld3.checkpoint import SnapshotInvalidatedError - - uw, model, mesh, swarm, _ = _fresh_model_mesh_and_swarm() - - snap = model.snapshot() - swarm.add_particles_with_coordinates(np.array([[0.5, 0.5]])) - - with pytest.raises(SnapshotInvalidatedError, match="_population_generation"): - model.restore(snap) - - -def test_swarm_internal_variables_are_not_captured(): - """Internal DMSwarm_* variables stay out of the snapshot key list.""" - uw, model, mesh, swarm, _ = _fresh_model_mesh_and_swarm() - - snap = model.snapshot() - keys = snap.backend.list_vectors() - swarmvar_keys = [k for k in keys if k.startswith(f"swarm:{id(swarm)}:var:")] - captured_names = {k.split(":var:")[1].split(":data")[0] for k in swarmvar_keys} - - # User variable present, PETSc-internal ones absent. - assert "material" in captured_names - assert not any(n.startswith("DMSwarm") for n in captured_names) From 9399f87469d3d7940c30e1eff9ec4f7b11403630 Mon Sep 17 00:00:00 2001 From: lmoresi Date: Mon, 11 May 2026 22:36:41 +1000 Subject: [PATCH 05/15] =?UTF-8?q?docs:=20correct=20snapshot=20design=20?= =?UTF-8?q?=E2=80=94=20rebuild-on-restore,=20mesh-adapt=20is=20v1.2?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The earlier design draft proposed Swarm._population_generation as an invalidation gate: counter mismatch between capture and restore would raise SnapshotInvalidatedError. That is wrong. The whole point of the toolkit is to undo intervening state changes — including particle motion, migration, and repopulation. Refusing on counter mismatch breaks the central use cases (RK staging, backtrack-on-instability, adaptive Δt retry, all of which migrate particles between capture and restore). Corrected swarm semantics: - Restore rebuilds the swarm's local population: clear current particles, re-add at captured per-rank coords via add_particles_with_coordinates(..., migrate=False), write captured per-variable data back into the new particles in order. - The _population_generation counter stays as informational metadata (logging, cache invalidation in other consumers, possible future fast-path optimisations), but it is not a restore gate. Mesh-adapt scope boundary reframed: - v1 keeps _mesh_version mismatch as a refusal, because the captured DOF arrays don't fit a different DM's section. - v1.2 will replace the refusal with a mesh-rebuild path on the same rebuild-on-restore principle: destroy the post-adapt DM, rebuild the pre-adapt one from captured topology + section, re-bind all MeshVariable / Swarm / solver wrappers. - v1 captures the topology / section info even though v1 restore ignores it, so the snapshot payload is forward-compatible with v1.2 without a schema bump. Architectural-work item 5 updated to match: snapshot captures per-rank particle coords + per-var arrays; restore clears + re-adds + writes; counter is informational, not a gate. Underworld development team with AI support from Claude Code (https://claude.com/claude-code) --- .../design/in_memory_checkpoint_design.md | 102 +++++++++++------- 1 file changed, 66 insertions(+), 36 deletions(-) diff --git a/docs/developer/design/in_memory_checkpoint_design.md b/docs/developer/design/in_memory_checkpoint_design.md index 4c3a3676..10fec6ad 100644 --- a/docs/developer/design/in_memory_checkpoint_design.md +++ b/docs/developer/design/in_memory_checkpoint_design.md @@ -298,35 +298,52 @@ Tokens are plain Python / numpy — never PETSc Vec or DM handles. On-disk "tokens" are paths to single HDF5 files. Either way, the DM-lifecycle hazards identified in earlier work do not apply. -## Generation counter for swarm invalidation - -Snapshots are valid until the population that produced them changes. -The audit identified the mutation sites: - -| file | line(s) | call | +## Restore semantics for swarms — rebuild, do not refuse + +**Correction (2026-05-11, post-review).** An earlier draft of this +section proposed using a per-swarm `_population_generation` counter as +an *invalidation gate*: if the counter at restore differs from the +counter at capture, raise rather than restore. That design is wrong. +The whole point of the toolkit is to undo state changes — including +particle motion / migration / repopulation between capture and +restore. Refusing on counter mismatch breaks the central use cases (RK +staging, backtrack-on-instability, adaptive Δt retry — all of which +*will* migrate particles between capture and restore). + +**The correct semantics: restore rebuilds the swarm's particle +population from the snapshot.** Specifically: + +1. Clear the swarm's current local particles. +2. Re-add the captured per-rank coordinates via + `add_particles_with_coordinates(saved_local_coords, migrate=False)`. + The mesh partition is deterministic and unchanged within v1 scope + (mesh-version check still applies), so particles that were local at + capture are local at restore — no migration step needed. +3. Write the captured per-particle variable data back into the + newly-added particles, in their captured order. + +The `Swarm._population_generation` counter is still useful as +*informational metadata* — it can flag in logs / metadata what +happened between capture and restore, it can feed cache invalidation +in other consumers, it can power future optimisations (e.g., a fast +in-place restore when the counter happens to match). But it is **not** +a restore gate. + +Mutation sites where the counter is incremented (current line numbers +re-derived per audit; the design's correctness does not depend on the +exact set as long as we over-bump rather than under-bump): + +| file | site | call | |---|---|---| -| `swarm.py` | 3083, 3085, 3109 | `populate()` → `dm.addNPoints()` | -| `swarm.py` | 3365 | `add_particles_with_coordinates()` | -| `swarm.py` | 3449 | `add_particles_with_global_coordinates()` | -| `swarm.py` | 3223, 3382 | `migrate(remove_sent_points=True)` | -| `swarm.py` | 4298 | remesh/repopulate path | -| `discretisation_mesh.py` | 3090 | `Mesh.adapt()` (indirect via `_mesh_version`) | - -A single `Swarm._population_generation` counter, incremented at all -seven sites, is sufficient. `Mesh._mesh_version` already exists for the -mesh-side analogue. - -```python -def swarm.restore(token): - if swarm._population_generation != token.generation_at_snapshot: - raise SnapshotInvalidatedError( - "swarm population changed since snapshot — restore not safe") - # ... write positions back, write svar values back, migrate -``` +| `swarm.py` | end of `populate()` | covers internal `dm.addNPoints()` calls | +| `swarm.py` | `Swarm.migrate()` after `migration_disabled` early-exit | bumps unconditionally; conservative no-op safe | +| `swarm.py` | after `dm.migrate()` in `add_particles_with_coordinates()` | direct PETSc DM call, not via `Swarm.migrate` | +| `swarm.py` | after `addNPoints` in `add_particles_with_global_coordinates()` | catches `migrate=False` callers | +| `swarm.py` | `advection()` remesh path after re-injection `addNPoints` | recycle-mode reinitialisation | -Constraint documented as part of the contract: snapshots cannot survive -a population-change event. Consumers that take long-lived snapshots -across such events get a clear error rather than silent corruption. +`Mesh._mesh_version` is a separate counter on the mesh side. In v1 +the mesh-version mismatch *does* refuse restore — see the next +section. ## Architectural work required @@ -361,9 +378,14 @@ In rough dependency order: history (`parameters.py:145`), any solver convergence-tracking state (audit pending). Each retrofit is small; total bounded by the number of classes (probably under ten). -5. **Swarm `_population_generation` counter.** Bumped at the seven - identified sites; checked on restore (within-process only — see - item 6 for the cross-process semantics). +5. **Swarm rebuild on restore + informational + `_population_generation` counter.** Snapshot captures per-rank + particle coordinates and per-variable arrays. Restore clears the + current local population and re-adds the captured particles at + their captured coords (see "Restore semantics for swarms" section + above for the corrected design). The counter is bumped at every + identified mutation site for informational use; it is **not** an + invalidation gate. 6. **Schema versioning + migration registry.** Each `State` dataclass carries a `_schema_version` integer. A central registry maps `(class, version)` to migration functions that lift older State @@ -388,12 +410,20 @@ In rough dependency order: ## Scope boundaries (NOT in v1) -- **Mesh adaptation roundtrip.** A snapshot taken before a mesh - adaptation event cannot be restored after the adaptation — the DM - identity has changed. Documented as a contract limitation; the - generation-counter pattern detects and refuses for in-memory - restores. (For on-disk restore in a fresh process, the question - doesn't arise — the model is being initialised from the snapshot.) +- **Mesh adaptation roundtrip — scheduled for v1.2, not a permanent + limitation.** A snapshot taken before a `mesh.adapt()` event in v1 + refuses restore via the `_mesh_version` check, because the captured + DOF arrays are sized for the pre-adapt section and writing them + in-place into the post-adapt DM would corrupt the run. **v1.2 will + replace the refusal with a mesh-rebuild path**: capture enough + topology / section info to destroy the post-adapt DM and rebuild + the pre-adapt one, then write DOFs into the rebuilt DM. The + principle is the same as the swarm rebuild (capture-the-state, + rebuild-on-restore); the implementation is more invasive because + every MeshVariable / Swarm / solver holds references into the DM + and those wrappers need to re-bind. v1's snapshot **captures the + topology / section info even though v1 restore ignores it**, so the + payload is forward-compatible with v1.2 without a schema bump. - **Replacing the existing `write_timestep` path.** The selective per-variable on-disk path continues unchanged. The new full-state on-disk backend is additive and serves a different need (faithful From 1682c63c0cac49907a9f379620b52fdd82484f3e Mon Sep 17 00:00:00 2001 From: lmoresi Date: Mon, 11 May 2026 22:44:30 +1000 Subject: [PATCH 06/15] checkpoint: rebuild-on-restore for swarms, v1.2-ready interfaces (PR 2 redo) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replaces the reverted counter-as-gate PR with rebuild-on-restore semantics for swarms. Restore now succeeds across the cases the earlier design wrongly refused (migrate, add_particles, repopulate between snapshot and restore) — these are precisely the cases the snapshot toolkit exists to enable (RK staging, backtrack on instability, adaptive Δt retry). Design changes since the reverted PR: - Swarm._population_generation stays, but is now purely informational: bumped at every population-mutation site for logging / debugging / downstream caches, but NOT consulted by restore. Restore rebuilds the local population from the snapshot regardless of intervening mutations. - Snapshot is keyed by stable name (mesh.name for meshes, f"swarm_{instance_number}" for swarms), not Python id(). Forward- compat for v1.1 cross-process restore and v1.2 mesh-rebuild after mesh.adapt() (where the wrapper survives but its DM is destroyed). - Restore logic moved off the snapshot module and onto wrapper methods: Mesh.apply_snapshot_payload and Swarm.apply_snapshot_payload. v1 implementations write back in place; v1.2's Mesh implementation can switch to rebuild-from-payload without touching snapshot.py. - Snapshot payloads include a reserved "topology": None slot on the mesh side, populated in v1.2 with section/DM-topology data sufficient to rebuild the DM. v1 leaves it None; the schema doesn't need to bump when v1.2 lands. Mesh restore (Mesh.apply_snapshot_payload at discretisation_mesh.py:2570): - Verify _mesh_version matches (v1 refusal; v1.2 will rebuild here). - Write coords via _deform_mesh, write per-MV gvec arrays + sync to local vec. Swarm restore (Swarm.apply_snapshot_payload at swarm.py:4084): - Drop every current local particle via dm.removePoint() (O(N) total, removes-from-end is O(1) per call). - addNPoints(n_saved), write coords directly to DMSwarmPIC_coor, set ranks. Deliberately bypasses add_particles_with_coordinates (which filters via points_in_domain and triggers migrate — both unnecessary here since saved coords were local at capture and the mesh hasn't changed). - Invalidate _canonical_data caches so subsequent var.data accesses re-resolve from PETSc. - Write captured per-variable data back in particle-order. Internal DMSwarm_* variables are filtered out at capture. Tests (11 total): - 6 mesh-only tests preserved. - 5 swarm tests, including the critical positive tests test_swarm_restore_after_migrate and test_swarm_restore_after_add_particles. Those are the cases the reverted PR wrongly raised on; they're now the central proof that the design works. Regression: 35 existing core tests pass unchanged. Underworld development team with AI support from Claude Code (https://claude.com/claude-code) --- src/underworld3/checkpoint/snapshot.py | 341 ++++++++++++------ .../discretisation/discretisation_mesh.py | 96 +++++ src/underworld3/swarm.py | 158 ++++++++ tests/test_0007_snapshot_inmemory.py | 132 +++++++ 4 files changed, 608 insertions(+), 119 deletions(-) diff --git a/src/underworld3/checkpoint/snapshot.py b/src/underworld3/checkpoint/snapshot.py index e4a20fb3..d9a78b57 100644 --- a/src/underworld3/checkpoint/snapshot.py +++ b/src/underworld3/checkpoint/snapshot.py @@ -7,10 +7,31 @@ processes (v1.1, on-disk backend) the model is initialised from the snapshot rather than restored to a previous state. +Forward-compatibility for v1.2 (mesh-adapt rebuild on restore) +--------------------------------------------------------------- +The snapshot module is structured so that v1.2 can replace the +``_mesh_version`` refusal with a true mesh-DM rebuild without touching +this module. Two principles support this: + +- **Capture by stable name, not by Python id.** Meshes are keyed by + ``mesh.name``, swarms by ``f"swarm_{instance_number}"``. Within a + single process this is overkill (object id would work); but it + trivialises v1.1 cross-process restore and v1.2 mesh-rebuild (where + the wrapper object survives but its DM is destroyed and recreated). + +- **Wrappers, not the snapshot module, decide how to apply a payload.** + ``Mesh.apply_snapshot_payload()`` and ``Swarm.apply_snapshot_payload()`` + receive a self-contained dict and decide what to do with it. v1 + implementations are in-place writes; v1.2's mesh implementation can + inspect the topology slot of the payload (left ``None`` by v1 + capture) and rebuild the DM if needed, without any change to the + capture / orchestration here. + This module implements the v1 scope: mesh coordinates and mesh-variable -DOFs. Swarm coverage, solver-internal Python state, on-disk backend, -schema versioning, and cross-process restore are scheduled for follow-up -PRs per the design note. +DOFs, plus swarm positions and user swarm-variable data with +rebuild-on-restore semantics. Solver-internal Python state, on-disk +backend, schema versioning, mesh-DM rebuild, and cross-process restore +are scheduled for follow-up PRs per the design note. """ from __future__ import annotations @@ -29,85 +50,128 @@ class SnapshotInvalidatedError(RuntimeError): """Raised when a snapshot can no longer be restored faithfully. - Triggers in v1: mesh ``_mesh_version`` differs from the snapshot - (mesh has been adapted; DM identity has changed), or a registered - mesh / mesh-variable named in the snapshot is no longer present on - the target :class:`underworld3.Model`. - - Future triggers (subsequent PRs): swarm population-generation - counter mismatch, on-disk schema version that has no migration - path. + Triggers in v1: + + - A captured mesh / swarm / variable name is no longer present on + the target :class:`underworld3.Model`. + - Mesh ``_mesh_version`` differs from the snapshot's captured + value. v1 treats this as fatal because the captured DOF arrays + are sized for the pre-adapt section. **v1.2 will replace this + refusal with a mesh-rebuild path** on the same principle as the + swarm rebuild — see the design note's mesh-adapt scope section. + + Notably **not** a trigger: swarm population mutation + (populate / migrate / add_particles / remesh) between capture and + restore. The swarm restore path *rebuilds* the local particle + population from the snapshot, so intervening mutations are exactly + what restore is for. The ``_population_generation`` counter on the + swarm is informational, not a restore gate. """ +def _swarm_stable_name(swarm) -> str: + """Per-process stable name for a swarm. Uses uw_object instance number.""" + return f"swarm_{swarm.instance_number}" + + @dataclass class Snapshot: """Unitary state token. Produced by :func:`snapshot`; consumed by :func:`restore`. Holds a - backend (where the bulk arrays live) plus per-Model bookkeeping — - which meshes were captured, which mesh variables were captured - under each mesh, and the mesh-version counters that gate - within-process restore. + backend (where the bulk arrays live) plus per-model bookkeeping — + which meshes and swarms were captured, in what order, with what + variable sets. Attributes ---------- backend - Where the captured arrays live. v1 always uses - :class:`InMemoryBackend`; v1.1 will add on-disk backends. + Where the captured arrays and small metadata live. schema_version Snapshot file-format version. Restore refuses on mismatch in v1; v1.1's migration registry will lift older versions to the current schema for on-disk restore only. - mesh_keys - Stable ordering of captured mesh identifiers (``id(mesh)``); - determines restore order. + mesh_names + Capture order of mesh names. ``mesh.name`` is the stable key. mesh_versions - Per-mesh ``_mesh_version`` at the moment of capture. Restore - compares against the current value; mismatch ⇒ - :class:`SnapshotInvalidatedError`. + Per-mesh ``_mesh_version`` at the moment of capture. v1 + compares strictly; v1.2 will rebuild on mismatch. meshvar_names - Mapping ``mesh_id → [var.clean_name, ...]`` — the mesh - variables captured for that mesh, in capture order. + Mapping ``mesh_name → [var clean_name, ...]``. + swarm_names + Capture order of swarm stable names + (``f"swarm_{instance_number}"``). + swarm_mesh_names + Mapping ``swarm_name → mesh_name`` so restore can verify the + swarm's parent mesh is still the captured one. + swarm_generations + Captured ``_population_generation`` per swarm — informational + metadata; *not* a restore gate. Useful for logs and debugging + ("this snapshot was taken at generation 7; the current swarm + is at 12"). + swarmvar_names + Mapping ``swarm_name → [user-var clean_name, ...]``. Internal + DMSwarm-prefixed variables are filtered out. metadata - User-visible bookkeeping (simulation time, step counter, free - text). Not load-bearing for restore correctness. + Free-form user/system metadata (simulation time, step counter, + ...). Not load-bearing for restore correctness. """ backend: CheckpointBackend schema_version: int = SNAPSHOT_SCHEMA_VERSION - mesh_keys: list[int] = field(default_factory=list) - mesh_versions: dict[int, int] = field(default_factory=dict) - meshvar_names: dict[int, list[str]] = field(default_factory=dict) + mesh_names: list[str] = field(default_factory=list) + mesh_versions: dict[str, int] = field(default_factory=dict) + meshvar_names: dict[str, list[str]] = field(default_factory=dict) + swarm_names: list[str] = field(default_factory=list) + swarm_mesh_names: dict[str, str] = field(default_factory=dict) + swarm_generations: dict[str, int] = field(default_factory=dict) + swarmvar_names: dict[str, list[str]] = field(default_factory=dict) metadata: dict[str, Any] = field(default_factory=dict) -def _mesh_coords_key(mesh_id: int) -> str: - return f"mesh:{mesh_id}:coords" +# ----- Backend key conventions ----- + +def _mesh_coords_key(mesh_name: str) -> str: + return f"mesh:{mesh_name}:coords" + + +def _meshvar_key(mesh_name: str, var_clean_name: str) -> str: + return f"mesh:{mesh_name}:var:{var_clean_name}:gvec" + + +def _swarm_coords_key(swarm_name: str) -> str: + return f"swarm:{swarm_name}:coords" -def _meshvar_key(mesh_id: int, var_clean_name: str) -> str: - return f"mesh:{mesh_id}:var:{var_clean_name}:gvec" +def _swarmvar_key(swarm_name: str, var_clean_name: str) -> str: + return f"swarm:{swarm_name}:var:{var_clean_name}:data" +def _is_internal_swarmvar(var_name: str) -> bool: + """Filter PETSc-managed internal swarm variables from user capture. + + ``DMSwarmPIC_coor`` is captured separately via the particle-coords + path. ``DMSwarm_X0`` and ``DMSwarm_remeshed`` carry recycle-related + bookkeeping that is regenerated on next solve and is out of scope + for v1 capture. + """ + return var_name.startswith("DMSwarm") + + +# ----- Capture (orchestration) ----- + def snapshot(model, *, path: Optional[str] = None) -> Snapshot: """Capture a unitary snapshot of the model's current state. - Parameters - ---------- - model - The :class:`underworld3.Model` whose registered meshes and - mesh variables should be captured. - path - Reserved for the v1.1 on-disk backend. Passing a non-``None`` - value raises :class:`NotImplementedError` in v1. - - Returns - ------- - Snapshot - Token suitable for passing to :func:`restore` on the same - ``model`` instance within the same process. v1 captures mesh - coordinates and mesh-variable global-vector DOF values. + Captures, in v1: each registered mesh's deformed coordinates and + every mesh-variable's global-vector DOFs; each registered swarm's + per-rank particle coordinates and user swarm-variable arrays. + + Pass ``path=...`` once the v1.1 on-disk backend lands. v1 raises + ``NotImplementedError``. + + See ``docs/developer/design/in_memory_checkpoint_design.md`` for + the design rationale and scope boundaries. """ if path is not None: raise NotImplementedError( @@ -116,44 +180,69 @@ def snapshot(model, *, path: Optional[str] = None) -> Snapshot: ) snap = Snapshot(backend=InMemoryBackend()) - for mesh_id, mesh in list(model._meshes.items()): - _capture_mesh(snap, mesh_id, mesh) + for mesh in list(model._meshes.values()): + _capture_mesh(snap, mesh) + for swarm in list(model._swarms.values()): + _capture_swarm(snap, swarm) return snap -def _capture_mesh(snap: Snapshot, mesh_id: int, mesh) -> None: - if mesh_id in snap.mesh_keys: - return - snap.mesh_keys.append(mesh_id) - snap.mesh_versions[mesh_id] = int(getattr(mesh, "_mesh_version", 0)) +def _capture_mesh(snap: Snapshot, mesh) -> None: + payload = mesh.snapshot_payload() + name = payload["name"] + if name in snap.mesh_names: + raise RuntimeError( + f"duplicate mesh name {name!r} in snapshot capture; mesh names " + f"must be unique within a Model" + ) + snap.mesh_names.append(name) + snap.mesh_versions[name] = payload["mesh_version"] - coords = np.asarray(mesh.X.coords) - snap.backend.save_vector(_mesh_coords_key(mesh_id), coords) + snap.backend.save_vector(_mesh_coords_key(name), payload["coords"]) var_names: list[str] = [] - for var in mesh.vars.values(): - var._sync_lvec_to_gvec() - gvec_array = np.asarray(var._gvec.array) - snap.backend.save_vector(_meshvar_key(mesh_id, var.clean_name), gvec_array) - var_names.append(var.clean_name) - snap.meshvar_names[mesh_id] = var_names + for var_clean_name, gvec_array in payload["vars"].items(): + snap.backend.save_vector(_meshvar_key(name, var_clean_name), gvec_array) + var_names.append(var_clean_name) + snap.meshvar_names[name] = var_names + + +def _capture_swarm(snap: Snapshot, swarm) -> None: + payload = swarm.snapshot_payload() + name = payload["name"] + if name in snap.swarm_names: + raise RuntimeError( + f"duplicate swarm name {name!r} in snapshot capture" + ) + snap.swarm_names.append(name) + snap.swarm_mesh_names[name] = payload["mesh_name"] + snap.swarm_generations[name] = payload["population_generation"] + snap.backend.save_vector(_swarm_coords_key(name), payload["coords"]) + + var_names: list[str] = [] + for var_clean_name, data in payload["vars"].items(): + snap.backend.save_vector(_swarmvar_key(name, var_clean_name), data) + var_names.append(var_clean_name) + snap.swarmvar_names[name] = var_names + + +# ----- Restore (orchestration) ----- def restore(model, snap: Snapshot) -> None: """Restore the model from a snapshot. - Restore order (within-process; cross-process is v1.1): + Mesh restore in v1 writes captured coords + DOFs back in place. If + the mesh's ``_mesh_version`` has moved since capture, restore + raises :class:`SnapshotInvalidatedError` — this becomes a rebuild + path in v1.2. - 1. Mesh coordinates (via :meth:`Mesh._deform_mesh`, which rebuilds - coordinate caches and notifies registered callbacks). - 2. Mesh-variable DOFs (global vector written, then synced to local - vector via ``subdm.globalToLocal``). - 3. ``_mesh_version`` is verified equal to the capture value before - any write; mismatch raises :class:`SnapshotInvalidatedError`. - - Future PRs extend the order to: swarm positions + migrate → swarm - variable values → solver-internal Python state (DDt history, - parameter mutation history) → generation-counter validation last. + Swarm restore *rebuilds* the local particle population: clears + current particles, re-adds at captured coords, writes captured + per-variable data back in order. This is the rebuild-on-restore + semantics described in the design note's "Restore semantics for + swarms" section — restore is precisely *for* the case where + particles have moved / been added / been removed since capture. Parameters ---------- @@ -166,8 +255,9 @@ def restore(model, snap: Snapshot) -> None: Raises ------ SnapshotInvalidatedError - Mesh ``_mesh_version`` has changed since capture, or a - captured mesh / variable is no longer registered on the model. + Captured mesh / swarm / variable is no longer registered on + the model, or mesh ``_mesh_version`` has moved since capture + (mesh-adapt is v1.2 scope). TypeError ``snap`` is not a :class:`Snapshot`. """ @@ -181,53 +271,66 @@ def restore(model, snap: Snapshot) -> None: f"current {SNAPSHOT_SCHEMA_VERSION}; on-disk migration is v1.1" ) - for mesh_id in snap.mesh_keys: - mesh = model._meshes.get(mesh_id) + meshes_by_name = {m.name: m for m in model._meshes.values()} + swarms_by_name = {_swarm_stable_name(s): s for s in model._swarms.values()} + + for mesh_name in snap.mesh_names: + mesh = meshes_by_name.get(mesh_name) if mesh is None: raise SnapshotInvalidatedError( - f"mesh id {mesh_id} from snapshot is not registered on this " - f"Model; within-process restore requires the originating Model" + f"mesh {mesh_name!r} from snapshot is not registered on " + f"this Model; within-process restore requires the originating " + f"Model" ) - current_version = int(getattr(mesh, "_mesh_version", 0)) - captured_version = snap.mesh_versions[mesh_id] - if current_version != captured_version: - raise SnapshotInvalidatedError( - f"mesh._mesh_version moved from {captured_version} to " - f"{current_version} since snapshot — likely mesh.adapt() or " - f"deform_mesh() invalidated the DM identity" - ) - _restore_mesh(snap, mesh_id, mesh) - - -def _restore_mesh(snap: Snapshot, mesh_id: int, mesh) -> None: - coords = snap.backend.load_vector(_mesh_coords_key(mesh_id)) - expected_shape = np.asarray(mesh.X.coords).shape - if coords.shape != expected_shape: - raise SnapshotInvalidatedError( - f"mesh coordinate shape changed: snapshot {coords.shape} vs " - f"current {expected_shape}" - ) - mesh._deform_mesh(coords) + payload = _build_mesh_payload(snap, mesh_name) + mesh.apply_snapshot_payload(payload) - current_vars = {var.clean_name: var for var in mesh.vars.values()} - for var_clean_name in snap.meshvar_names[mesh_id]: - var = current_vars.get(var_clean_name) - if var is None: + for swarm_name in snap.swarm_names: + swarm = swarms_by_name.get(swarm_name) + if swarm is None: raise SnapshotInvalidatedError( - f"mesh variable {var_clean_name!r} from snapshot is not " - f"present on mesh; restore requires the same variable set" + f"swarm {swarm_name!r} from snapshot is not registered on " + f"this Model" ) - var._sync_lvec_to_gvec() # ensures _gvec exists with a current size - saved = snap.backend.load_vector(_meshvar_key(mesh_id, var_clean_name)) - current_shape = np.asarray(var._gvec.array).shape - if saved.shape != current_shape: + expected_mesh_name = snap.swarm_mesh_names[swarm_name] + if swarm.mesh.name != expected_mesh_name: raise SnapshotInvalidatedError( - f"variable {var_clean_name!r} gvec shape changed: snapshot " - f"{saved.shape} vs current {current_shape}" + f"swarm {swarm_name!r} parent mesh changed from " + f"{expected_mesh_name!r} to {swarm.mesh.name!r} since " + f"snapshot" + ) + payload = _build_swarm_payload(snap, swarm_name) + swarm.apply_snapshot_payload(payload) + + +def _build_mesh_payload(snap: Snapshot, mesh_name: str) -> dict: + return { + "name": mesh_name, + "captured_mesh_version": snap.mesh_versions[mesh_name], + "coords": snap.backend.load_vector(_mesh_coords_key(mesh_name)), + # Topology is None in v1; v1.2 mesh-rebuild path will populate + # this slot (e.g., section view data) without bumping the + # schema version, because v1 reads ignore the key. + "topology": None, + "vars": { + var_clean_name: snap.backend.load_vector( + _meshvar_key(mesh_name, var_clean_name) + ) + for var_clean_name in snap.meshvar_names[mesh_name] + }, + } + + +def _build_swarm_payload(snap: Snapshot, swarm_name: str) -> dict: + return { + "name": swarm_name, + "mesh_name": snap.swarm_mesh_names[swarm_name], + "captured_population_generation": snap.swarm_generations[swarm_name], + "coords": snap.backend.load_vector(_swarm_coords_key(swarm_name)), + "vars": { + var_clean_name: snap.backend.load_vector( + _swarmvar_key(swarm_name, var_clean_name) ) - var._gvec.array[...] = saved - iset, subdm = mesh.dm.createSubDM(var.field_id) - subdm.globalToLocal(var._gvec, var._lvec, addv=False) - iset.destroy() - subdm.destroy() - mesh._stale_lvec = True + for var_clean_name in snap.swarmvar_names[swarm_name] + }, + } diff --git a/src/underworld3/discretisation/discretisation_mesh.py b/src/underworld3/discretisation/discretisation_mesh.py index bf3e4075..d6bb95c6 100644 --- a/src/underworld3/discretisation/discretisation_mesh.py +++ b/src/underworld3/discretisation/discretisation_mesh.py @@ -2567,6 +2567,102 @@ def write_checkpoint( uw.mpi.barrier() # should not be required viewer.destroy() + # ----- Unitary snapshot / restore ----- + # + # See ``src/underworld3/checkpoint/snapshot.py`` and + # ``docs/developer/design/in_memory_checkpoint_design.md``. v1 + # captures deformed coords + per-MV global-vector DOFs; v1.2 will + # add topology / section capture so the DM can be rebuilt on + # restore after ``mesh.adapt()``. + + def snapshot_payload(self) -> dict: + """Return a self-contained dict describing this mesh's state. + + The returned dict is consumed by + :mod:`underworld3.checkpoint.snapshot` capture. Keys: + + - ``name``: stable string identifier for the mesh. + - ``mesh_version``: current ``_mesh_version`` integer. + - ``coords``: deformed mesh coordinates (numpy array). + - ``vars``: ``{var.clean_name: gvec_array.copy()}`` for every + mesh variable on this mesh. + + v1.2 will additionally populate a ``topology`` key with + section / DM-topology data sufficient to rebuild the DM on + restore. + """ + coords = numpy.asarray(self.X.coords).copy() + var_arrays: Dict[str, numpy.ndarray] = {} + for var in self.vars.values(): + var._sync_lvec_to_gvec() + var_arrays[var.clean_name] = numpy.asarray(var._gvec.array).copy() + return { + "name": self.name, + "mesh_version": int(getattr(self, "_mesh_version", 0)), + "coords": coords, + "vars": var_arrays, + } + + def apply_snapshot_payload(self, payload: dict) -> None: + """Restore this mesh from a payload produced by :meth:`snapshot_payload`. + + v1 implementation writes coordinates and per-variable DOFs + back in place. The captured DOF arrays must match the current + section, which means ``_mesh_version`` must equal the captured + value — mesh-adapt during the interval would have resized the + section and is detected as a v1 refusal here. + + v1.2 will replace the ``_mesh_version`` refusal with a + rebuild-from-payload path: destroy the current DM, rebuild + from ``payload["topology"]``, allocate vectors, write DOFs, + and re-bind MeshVariable / Swarm wrappers. The interface stays + the same; only this method's body changes. + """ + from underworld3.checkpoint.snapshot import SnapshotInvalidatedError + + current_version = int(getattr(self, "_mesh_version", 0)) + captured_version = int(payload["captured_mesh_version"]) + if current_version != captured_version: + raise SnapshotInvalidatedError( + f"mesh {self.name!r}: _mesh_version moved from " + f"{captured_version} to {current_version} since snapshot. " + f"mesh.adapt() rebuild on restore is scheduled for v1.2; " + f"v1 refuses rather than corrupt the DOF arrays" + ) + + coords = numpy.asarray(payload["coords"]) + expected_shape = numpy.asarray(self.X.coords).shape + if coords.shape != expected_shape: + raise SnapshotInvalidatedError( + f"mesh {self.name!r}: coordinate shape changed " + f"({coords.shape} vs current {expected_shape}); programming " + f"error since _mesh_version matched" + ) + self._deform_mesh(coords) + + current_vars = {var.clean_name: var for var in self.vars.values()} + for var_clean_name, saved_array in payload["vars"].items(): + var = current_vars.get(var_clean_name) + if var is None: + raise SnapshotInvalidatedError( + f"mesh {self.name!r}: variable {var_clean_name!r} " + f"from snapshot is no longer present" + ) + var._sync_lvec_to_gvec() + current_shape = numpy.asarray(var._gvec.array).shape + if saved_array.shape != current_shape: + raise SnapshotInvalidatedError( + f"mesh {self.name!r}: variable {var_clean_name!r} gvec " + f"shape changed ({saved_array.shape} vs current " + f"{current_shape})" + ) + var._gvec.array[...] = saved_array + iset, subdm = self.dm.createSubDM(var.field_id) + subdm.globalToLocal(var._gvec, var._lvec, addv=False) + iset.destroy() + subdm.destroy() + self._stale_lvec = True + @timing.routine_timer_decorator def write(self, filename: str, index: Optional[int] = None): """ diff --git a/src/underworld3/swarm.py b/src/underworld3/swarm.py index ff8dc593..a0aeef7a 100644 --- a/src/underworld3/swarm.py +++ b/src/underworld3/swarm.py @@ -2493,6 +2493,15 @@ def __init__(self, mesh, recycle_rate=0, verbose=False, clip_to_mesh=True): # Mesh version tracking for coordinate change detection self._mesh_version = mesh._mesh_version + # Informational counter incremented at every particle-population + # mutation site (populate, migrate, add_particles_*, advection + # remesh). NOT a snapshot-restore invalidation gate — restore + # rebuilds the local population from the snapshot regardless of + # what happened in between. Useful for logging, debugging, and + # any cache that wants to know "did the population change?" + # See docs/developer/design/in_memory_checkpoint_design.md. + self._population_generation = 0 + # Register this swarm with the mesh for coordinate change notifications mesh.register_swarm(self) @@ -3286,6 +3295,9 @@ def populate( offset = swarm_orig_size * i self._remeshed.data[offset::, 0] = i + # Informational: particle population just changed. + self._population_generation += 1 + return @timing.routine_timer_decorator @@ -3315,6 +3327,11 @@ def migrate( if self._migration_disabled: return + # Informational: migration may move or drop particles. Bump + # unconditionally; restore is not gated on this counter so a + # conservative no-op bump is harmless. + self._population_generation += 1 + from time import time if delete_lost_points is None: @@ -3536,6 +3553,10 @@ def add_particles_with_coordinates(self, coordinatesArray) -> int: if hasattr(var, "_canonical_data"): var._canonical_data = None + # Informational: addNPoints + direct dm.migrate path doesn't go + # through Swarm.migrate, so bump explicitly. + self._population_generation += 1 + return npoints @timing.routine_timer_decorator @@ -3604,6 +3625,10 @@ def add_particles_with_global_coordinates( self.dm.finalizeFieldRegister() self.dm.addNPoints(npoints=npoints) + # Informational: population changed even if migrate=False is + # passed (in which case Swarm.migrate's bump wouldn't fire). + self._population_generation += 1 + # Add new points with provided coords # Record the current rank (migration needs to know where we start from !) @@ -4056,6 +4081,136 @@ def vars(self): """ return self._vars + # ----- Unitary snapshot / restore ----- + # + # See ``src/underworld3/checkpoint/snapshot.py`` and + # ``docs/developer/design/in_memory_checkpoint_design.md``. Capture + # records the per-rank particle layout and user-variable arrays. + # Restore rebuilds the local population from the snapshot rather + # than refusing on counter mismatch — restore is precisely for the + # case where particles have moved / migrated / been repopulated. + + def _snapshot_stable_name(self) -> str: + """Per-process stable name. ``instance_number`` comes from uw_object.""" + return f"swarm_{self.instance_number}" + + def snapshot_payload(self) -> dict: + """Return a self-contained dict describing this swarm's state. + + Captured: per-rank particle coordinates (from + ``DMSwarmPIC_coor``) and every user swarm-variable's data + array. PETSc-internal variables (``DMSwarmPIC_coor``, + ``DMSwarm_X0``, ``DMSwarm_remeshed``) are excluded — their + contents either come from the captured coords or are + regenerated on the next solve. + """ + coord_field = self.dm.getField("DMSwarmPIC_coor").reshape( + (-1, self.dim) + ) + coords = np.asarray(coord_field).copy() + self.dm.restoreField("DMSwarmPIC_coor") + + var_arrays: dict = {} + for var in list(self._vars.values()): + if var.name.startswith("DMSwarm"): + continue + var_arrays[var.clean_name] = np.asarray(var.data).copy() + + return { + "name": self._snapshot_stable_name(), + "mesh_name": self.mesh.name, + "population_generation": int(self._population_generation), + "coords": coords, + "vars": var_arrays, + } + + def apply_snapshot_payload(self, payload: dict) -> None: + """Rebuild this swarm's local particle population from a payload. + + Algorithm: + + 1. Drop every current local particle (``dm.removePoint`` from + the end is O(1) per call, O(N) total). + 2. Add the captured-rank's particles back via the raw PETSc + primitives — ``addNPoints`` then writing the coord field + directly. We deliberately bypass + :meth:`add_particles_with_coordinates` because that method + filters via ``points_in_domain`` (slow) and triggers + ``dm.migrate`` (unnecessary — saved coords were already + local at capture time, and the mesh hasn't changed). + 3. Write captured per-variable data back. The local particle + count matches the captured count because we just put the + same particles back in the same order. + + This bumps ``_population_generation`` once (from the addNPoints + step in restore), which is correct: the population *did* just + change. Downstream consumers that care can compare against the + captured value in ``payload['captured_population_generation']``. + """ + from underworld3.checkpoint.snapshot import SnapshotInvalidatedError + + saved_coords = np.asarray(payload["coords"]) + + # Step 1: clear local population. removePoint() removes the last + # particle, so this is O(N) total. + while self.dm.getLocalSize() > 0: + self.dm.removePoint() + + # Step 2: re-add. add raw points, write coords + ranks directly. + n_saved = int(saved_coords.shape[0]) + if n_saved > 0: + self.dm.finalizeFieldRegister() + self.dm.addNPoints(npoints=n_saved) + + coord_field = self.dm.getField("DMSwarmPIC_coor").reshape( + (-1, self.dim) + ) + if coord_field.shape != saved_coords.shape: + self.dm.restoreField("DMSwarmPIC_coor") + raise SnapshotInvalidatedError( + f"swarm {self._snapshot_stable_name()!r}: after " + f"addNPoints({n_saved}) the coord field has shape " + f"{coord_field.shape}, expected {saved_coords.shape}" + ) + coord_field[...] = saved_coords + self.dm.restoreField("DMSwarmPIC_coor") + + rank_field = self.dm.getField("DMSwarm_rank") + rank_field[...] = uw.mpi.rank + self.dm.restoreField("DMSwarm_rank") + + # Invalidate canonical-data caches — the underlying arrays + # have been reallocated by the addNPoints path. + if hasattr(self._particle_coordinates, "_canonical_data"): + self._particle_coordinates._canonical_data = None + for var in self._vars.values(): + if hasattr(var, "_canonical_data"): + var._canonical_data = None + + # The clear+re-add path bumped _population_generation already + # (we don't bump on removePoint, but addNPoints isn't bumped + # either — these are raw PETSc calls). For consistency with + # other mutation paths, bump explicitly here. + self._population_generation += 1 + + # Step 3: write captured per-variable data. + current_vars = {var.clean_name: var for var in self._vars.values()} + for var_clean_name, saved in payload["vars"].items(): + var = current_vars.get(var_clean_name) + if var is None: + raise SnapshotInvalidatedError( + f"swarm {self._snapshot_stable_name()!r}: variable " + f"{var_clean_name!r} from snapshot is not present" + ) + current = np.asarray(var.data) + if current.shape != saved.shape: + raise SnapshotInvalidatedError( + f"swarm {self._snapshot_stable_name()!r}: variable " + f"{var_clean_name!r} data shape mismatch — current " + f"{current.shape} vs snapshot {saved.shape}" + ) + current[...] = saved + def _legacy_access(self, *writeable_vars: SwarmVariable): """ This context manager makes the underlying swarm variables data available to @@ -4470,6 +4625,9 @@ def advection( self.dm.addNPoints(num_remeshed_points) + # Informational: remesh just re-injected particles. + self._population_generation += 1 + ## cellid = self.dm.getField("DMSwarm_cellid") coords = self.dm.getField("DMSwarmPIC_coor").reshape((-1, self.dim)) rmsh = self.dm.getField("DMSwarm_remeshed") diff --git a/tests/test_0007_snapshot_inmemory.py b/tests/test_0007_snapshot_inmemory.py index 71dfbb43..fbf69167 100644 --- a/tests/test_0007_snapshot_inmemory.py +++ b/tests/test_0007_snapshot_inmemory.py @@ -113,3 +113,135 @@ def test_snapshot_path_is_v1_1_scope(): with pytest.raises(NotImplementedError): model.snapshot(path="/tmp/should_not_be_written.h5") + + +# ----- Swarm coverage (rebuild-on-restore semantics) ----- + + +def _fresh_model_mesh_and_swarm(with_material=True): + """Fresh model + mesh + populated swarm. svar must be added pre-populate.""" + import underworld3 as uw + + uw.reset_default_model() + model = uw.get_default_model() + mesh = uw.meshing.UnstructuredSimplexBox( + minCoords=(0.0, 0.0), maxCoords=(1.0, 1.0), cellSize=1.0 / 4.0 + ) + swarm = uw.swarm.Swarm(mesh) + material = None + if with_material: + material = swarm.add_variable("material", 1, dtype=float) + swarm.populate(fill_param=2) + return uw, model, mesh, swarm, material + + +def _swarm_coords(swarm): + """Return a copy of the current per-rank particle coords.""" + field = swarm.dm.getField("DMSwarmPIC_coor").reshape((-1, swarm.dim)) + out = np.asarray(field).copy() + swarm.dm.restoreField("DMSwarmPIC_coor") + return out + + +def test_swarm_no_change_roundtrip(): + """Trivial case: snapshot, scribble, restore — both coords and svar recovered.""" + uw, model, mesh, swarm, material = _fresh_model_mesh_and_swarm() + material.data[:, 0] = 0.5 * _swarm_coords(swarm)[:, 0] + coords_pre = _swarm_coords(swarm) + material_pre = np.asarray(material.data).copy() + + snap = model.snapshot() + + coord_field = swarm.dm.getField("DMSwarmPIC_coor").reshape((-1, swarm.dim)) + coord_field[...] = -99.0 + swarm.dm.restoreField("DMSwarmPIC_coor") + material.data[...] = -99.0 + + model.restore(snap) + + assert np.allclose(_swarm_coords(swarm), coords_pre) + assert np.allclose(np.asarray(material.data), material_pre) + + +def test_swarm_restore_after_migrate(): + """Migrate between snapshot and restore: restore puts the swarm back. This + is the case my earlier counter-as-gate design wrongly refused.""" + uw, model, mesh, swarm, material = _fresh_model_mesh_and_swarm() + material.data[:, 0] = 1.0 + coords_pre = _swarm_coords(swarm) + material_pre = np.asarray(material.data).copy() + pop_gen_pre = swarm._population_generation + + snap = model.snapshot() + + # Mutate: migrate() will bump the counter regardless of whether + # particles actually moved. Restore must succeed anyway. + swarm.migrate(remove_sent_points=True) + assert swarm._population_generation > pop_gen_pre, "migrate didn't bump counter" + + model.restore(snap) + + assert np.allclose(_swarm_coords(swarm), coords_pre), ( + "restore did not recover particle coords across a migrate event" + ) + assert np.allclose(np.asarray(material.data), material_pre), ( + "restore did not recover svar data across a migrate event" + ) + + +def test_swarm_restore_after_add_particles(): + """Particles added between snapshot and restore: restore *removes* them.""" + uw, model, mesh, swarm, material = _fresh_model_mesh_and_swarm() + material.data[:, 0] = 2.0 + coords_pre = _swarm_coords(swarm) + material_pre = np.asarray(material.data).copy() + npre = swarm.dm.getLocalSize() + + snap = model.snapshot() + + swarm.add_particles_with_coordinates( + np.array([[0.5, 0.5], [0.25, 0.75]]) + ) + assert swarm.dm.getLocalSize() != npre, "add_particles didn't grow swarm" + + model.restore(snap) + + assert swarm.dm.getLocalSize() == npre, ( + "restore did not roll back to the captured particle count" + ) + assert np.allclose(_swarm_coords(swarm), coords_pre) + assert np.allclose(np.asarray(material.data), material_pre) + + +def test_swarm_population_generation_is_informational_not_a_gate(): + """The counter ticks up across mutations but does NOT block restore.""" + uw, model, mesh, swarm, _ = _fresh_model_mesh_and_swarm() + gen_at_capture = swarm._population_generation + + snap = model.snapshot() + + swarm.migrate(remove_sent_points=True) + swarm.add_particles_with_coordinates(np.array([[0.5, 0.5]])) + gen_during = swarm._population_generation + assert gen_during > gen_at_capture + + # Restore is expected to *succeed*, not raise. + model.restore(snap) + + # And the counter has moved on from where it was at capture, + # because restore itself counts as a population change. + assert swarm._population_generation > gen_at_capture + + +def test_swarm_internal_variables_are_not_captured(): + """Internal DMSwarm_* variables stay out of the snapshot key list.""" + uw, model, mesh, swarm, material = _fresh_model_mesh_and_swarm() + + snap = model.snapshot() + keys = snap.backend.list_vectors() + swarm_name = swarm._snapshot_stable_name() + svar_keys = [k for k in keys if k.startswith(f"swarm:{swarm_name}:var:")] + captured_names = {k.split(":var:")[1].split(":data")[0] for k in svar_keys} + + assert "material" in captured_names + assert not any(n.startswith("DMSwarm") for n in captured_names) From 5ef1a119f873e200ce614f3c013cc20c14e33d4f Mon Sep 17 00:00:00 2001 From: lmoresi Date: Mon, 11 May 2026 23:04:24 +1000 Subject: [PATCH 07/15] checkpoint: Snapshottable contract + DDt Symbolic retrofit (PR 3) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Introduces the state-as-dataclass serialisation contract from the design note and applies it to the canonical DDt flavor (Symbolic). PR 4 will mechanically extend the same pattern to Eulerian, SemiLagrangian, Lagrangian, and Lagrangian_Swarm — they share the same dt_history / history_initialised / n_solves_completed / dt core; the variation is purely in how psi_star is bound. New infrastructure: - src/underworld3/checkpoint/state.py: - SnapshottableState dataclass base with _schema_version field (load-bearing for v1.1 on-disk migration; checked for strict equality in v1 since capture and restore are same-process). - Snapshottable runtime_checkable protocol — anything with a .state attribute returning a SnapshottableState. Drives discovery in Model.snapshot(). - Model._state_bearers (WeakSet) + Model._register_state_bearer(): state-bearing helpers self-register on construction without pinning their lifetime. DDt Symbolic retrofit (option (B)-style adapter per design note): - DDtSymbolicState dataclass with dt_history, history_initialised, n_solves_completed, dt, psi_star. - Symbolic gets a `state` property (builds the dataclass from the existing private attrs on read) and a `state.setter` (unpacks, validates schema version + dt_history length, writes attrs back, re-derives BDF/AM coefficient values so downstream reads see the restored state immediately rather than waiting for the next update_pre_solve). - Symbolic.__init__ auto-registers with the default model. Snapshot/restore wiring: - Snapshot.state_bearers: list of (stable_key, state_dataclass) with stable_key = f"{type(obj).__name__}_{obj.instance_number}". - snapshot() iterates Model._state_bearers, deepcopies obj.state, stores the copy. Deepcopy isolates the snapshot from later mutation of the live state-bearer. - restore() matches captured states to current state-bearers by stable_key, deepcopies, writes via obj.state setter. Missing state-bearer → SnapshotInvalidatedError. Drive-by fix in Mesh.snapshot_payload (caught by the DDt tests): mesh variables with _gvec=None (lazy allocation: var created but never written to) are now skipped during capture rather than crashing on var._gvec.array. Restore correspondingly only touches variables present in payload["vars"], so an unallocated-at-capture variable is left in its current state. Tests (6 new on top of the 11 mesh+swarm ones): - DDt auto-registers in Model._state_bearers - .state returns a SnapshottableState with correct schema version - mid-trajectory snapshot+restore recovers dt_history, history_initialised, n_solves_completed, dt - wrong _schema_version on apply raises ValueError - dt_history length mismatch on apply raises ValueError - snapshot is a deep copy: scribbling live DDt internals doesn't leak into the captured state-bearer payload 17/17 snapshot tests pass; 41 existing core + DDt tests still green. Underworld development team with AI support from Claude Code (https://claude.com/claude-code) --- src/underworld3/checkpoint/__init__.py | 3 + src/underworld3/checkpoint/snapshot.py | 50 ++++++++ src/underworld3/checkpoint/state.py | 77 ++++++++++++ .../discretisation/discretisation_mesh.py | 5 + src/underworld3/model.py | 19 +++ src/underworld3/systems/ddt.py | 117 +++++++++++++++++- tests/test_0007_snapshot_inmemory.py | 109 ++++++++++++++++ 7 files changed, 379 insertions(+), 1 deletion(-) create mode 100644 src/underworld3/checkpoint/state.py diff --git a/src/underworld3/checkpoint/__init__.py b/src/underworld3/checkpoint/__init__.py index d211adaf..fcde6412 100644 --- a/src/underworld3/checkpoint/__init__.py +++ b/src/underworld3/checkpoint/__init__.py @@ -25,6 +25,7 @@ snapshot, restore, ) +from .state import Snapshottable, SnapshottableState __all__ = [ "CheckpointBackend", @@ -34,4 +35,6 @@ "SnapshotInvalidatedError", "snapshot", "restore", + "Snapshottable", + "SnapshottableState", ] diff --git a/src/underworld3/checkpoint/snapshot.py b/src/underworld3/checkpoint/snapshot.py index d9a78b57..eff45bac 100644 --- a/src/underworld3/checkpoint/snapshot.py +++ b/src/underworld3/checkpoint/snapshot.py @@ -36,12 +36,14 @@ from __future__ import annotations +import copy from dataclasses import dataclass, field from typing import Any, Optional import numpy as np from .backend import CheckpointBackend, InMemoryBackend +from .state import SnapshottableState SNAPSHOT_SCHEMA_VERSION = 1 @@ -126,6 +128,12 @@ class Snapshot: swarm_mesh_names: dict[str, str] = field(default_factory=dict) swarm_generations: dict[str, int] = field(default_factory=dict) swarmvar_names: dict[str, list[str]] = field(default_factory=dict) + # State-bearer captures: list of (stable_key, state_dataclass). + # stable_key is f"{type(obj).__name__}_{obj.instance_number}", matched + # at restore against the same key derived from currently-registered + # state-bearers. List preserves capture order — informational only, + # since lookup is by key. + state_bearers: list = field(default_factory=list) metadata: dict[str, Any] = field(default_factory=dict) @@ -184,6 +192,8 @@ def snapshot(model, *, path: Optional[str] = None) -> Snapshot: _capture_mesh(snap, mesh) for swarm in list(model._swarms.values()): _capture_swarm(snap, swarm) + for obj in list(model._state_bearers): + _capture_state_bearer(snap, obj) return snap @@ -207,6 +217,32 @@ def _capture_mesh(snap: Snapshot, mesh) -> None: snap.meshvar_names[name] = var_names +def _state_bearer_key(obj) -> str: + """Stable per-process key for a Snapshottable. ``instance_number`` + comes from ``uw_object`` and is unique across the run.""" + return f"{type(obj).__name__}_{obj.instance_number}" + + +def _capture_state_bearer(snap: Snapshot, obj) -> None: + """Pull ``obj.state`` and store a deep copy. + + Deep copy ensures later mutations on the live state-bearer don't + leak into the captured token. The dataclass itself is the storage + here (no separate backend.save_vector call) because state + dataclasses are small Python objects, not bulk numerical arrays. + v1.1's on-disk backend will route the dataclass through the + backend (HDF5 attrs/groups); v1 holds them in the in-memory Snapshot + directly. + """ + state = obj.state + if not isinstance(state, SnapshottableState): + raise TypeError( + f"{type(obj).__name__}.state must be a SnapshottableState, " + f"got {type(state).__name__}" + ) + snap.state_bearers.append((_state_bearer_key(obj), copy.deepcopy(state))) + + def _capture_swarm(snap: Snapshot, swarm) -> None: payload = swarm.snapshot_payload() name = payload["name"] @@ -302,6 +338,20 @@ def restore(model, snap: Snapshot) -> None: payload = _build_swarm_payload(snap, swarm_name) swarm.apply_snapshot_payload(payload) + if snap.state_bearers: + bearers_by_key = { + _state_bearer_key(o): o for o in list(model._state_bearers) + } + for key, captured_state in snap.state_bearers: + obj = bearers_by_key.get(key) + if obj is None: + raise SnapshotInvalidatedError( + f"state-bearer {key!r} from snapshot is not registered " + f"on this Model; restore requires the originating " + f"Model" + ) + obj.state = copy.deepcopy(captured_state) + def _build_mesh_payload(snap: Snapshot, mesh_name: str) -> dict: return { diff --git a/src/underworld3/checkpoint/state.py b/src/underworld3/checkpoint/state.py new file mode 100644 index 00000000..646b8b2a --- /dev/null +++ b/src/underworld3/checkpoint/state.py @@ -0,0 +1,77 @@ +"""State-as-dataclass contract for snapshot capture/restore. + +Per the design note's ``General serialisation contract for solver-internal +state'' section, the chosen pattern for new solver-internal classes is +**option (C): state as a first-class dataclass**. A class declares a +``state`` attribute that is a dataclass; the snapshot mechanism reads it +automatically. + +For existing classes being retrofitted (DDt today, parameter mutation +history next, any solver convergence-tracking state after that), the +design recommends **option (B)-style adapters**: a ``state`` property +that *derives* a dataclass from the existing private attributes, and a +matching setter that writes the dataclass values back. The dataclass is +not the authoritative store — the private attrs are — but the snapshot +mechanism only sees the dataclass, so the snapshot/restore contract is +the same as for option (C). + +This module defines: + +- :class:`SnapshottableState` — the dataclass base every state object + should inherit / mimic. Carries the ``_schema_version`` field that + drives the v1.1 on-disk migration registry. +- :class:`Snapshottable` — a structural protocol; an object is + Snapshottable if it exposes a ``state`` attribute that is a + :class:`SnapshottableState`. Used by ``Model.snapshot()`` discovery. + +The actual State dataclasses for specific classes (``DDtState``, +``ParameterRegistryState``) live next to those classes in the systems +/ utilities packages. +""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Protocol, runtime_checkable + + +@dataclass +class SnapshottableState: + """Base for every per-class State dataclass. + + Subclasses add their own fields. The single mandatory field is + ``_schema_version`` — an integer that v1.1's on-disk migration + registry uses to lift older snapshots to the current schema. In + v1 (in-memory only) the version is checked for strict equality; + any mismatch is a programming error since capture and restore + happen in the same process. + """ + + _schema_version: int = 1 + + +@runtime_checkable +class Snapshottable(Protocol): + """Structural protocol for state-bearing objects. + + An object is Snapshottable if:: + + obj.state # readable + obj.state = obj.state # writable + isinstance(obj.state, SnapshottableState) + + The snapshot mechanism uses :attr:`state` to capture and restore. + Implementations choose between: + + - **Option (C), authoritative dataclass.** ``state`` is a stored + attribute holding the dataclass; every mutation site on the class + writes ``self.state.`` directly. Best for new code. + + - **Option (B), derived dataclass.** ``state`` is a property that + builds the dataclass from existing private attrs on each read, + and the setter writes attrs back. Best for retrofits of existing + classes (DDt, ParameterRegistry). + """ + + @property + def state(self) -> SnapshottableState: ... diff --git a/src/underworld3/discretisation/discretisation_mesh.py b/src/underworld3/discretisation/discretisation_mesh.py index d6bb95c6..515c7511 100644 --- a/src/underworld3/discretisation/discretisation_mesh.py +++ b/src/underworld3/discretisation/discretisation_mesh.py @@ -2595,6 +2595,11 @@ def snapshot_payload(self) -> dict: var_arrays: Dict[str, numpy.ndarray] = {} for var in self.vars.values(): var._sync_lvec_to_gvec() + # Variables created but never touched have _gvec=None (lazy + # allocation in MeshVariable). They carry no data so they + # contribute nothing to the snapshot — skip cleanly. + if var._gvec is None: + continue var_arrays[var.clean_name] = numpy.asarray(var._gvec.array).copy() return { "name": self.name, diff --git a/src/underworld3/model.py b/src/underworld3/model.py index ca884569..248aa745 100644 --- a/src/underworld3/model.py +++ b/src/underworld3/model.py @@ -126,6 +126,14 @@ class Model(PintNativeModelMixin, BaseModel): _variables: Dict[str, Any] = PrivateAttr(default_factory=dict) _solvers: Dict[str, Any] = PrivateAttr(default_factory=dict) + # State-bearing objects (DDt instances, parameter-history holders, + # any helper that exposes a .state dataclass per the Snapshottable + # protocol — see src/underworld3/checkpoint/snapshot.py). WeakSet so + # the registry does not pin objects past their natural lifetime. + # Snapshot capture and restore iterate this registry; consumers + # other than checkpoint may also walk it. + _state_bearers: Any = PrivateAttr(default_factory=weakref.WeakSet) + def __init__(self, name: Optional[str] = None, **kwargs): """ Initialize a new Model instance. @@ -565,6 +573,17 @@ def get_solver(self, name: str): """Get a solver by name from the model registry""" return self._solvers.get(name) + def _register_state_bearer(self, obj) -> None: + """Register a Snapshottable object with this model. + + Called by helper classes (DDt, parameter-history holders, ...) + in their ``__init__``. ``obj`` should expose a ``.state`` + attribute returning a dataclass with ``_schema_version``. + Membership is via WeakSet, so registration does not extend + ``obj``'s lifetime. + """ + self._state_bearers.add(obj) + def snapshot(self, *, path: Optional[str] = None): """Capture a unitary in-memory snapshot of the model's state. diff --git a/src/underworld3/systems/ddt.py b/src/underworld3/systems/ddt.py index 39bf5692..d21ebebd 100644 --- a/src/underworld3/systems/ddt.py +++ b/src/underworld3/systems/ddt.py @@ -54,17 +54,69 @@ from sympy import sympify import numpy as np -from typing import Optional, Callable, Union +from dataclasses import dataclass, field +from typing import Any, Optional, Callable, Union import underworld3 as uw from underworld3 import VarType import underworld3.timing as timing from underworld3.utilities._api_tools import uw_object +from underworld3.checkpoint.state import SnapshottableState from petsc4py import PETSc +# ----- Snapshot state dataclasses for DDt flavors ----- +# +# Per the design note's "General serialisation contract" section, each +# DDt class exposes a derived State dataclass via ``.state``. The +# private ``_dt_history`` / ``_history_initialised`` / etc. remain the +# authoritative store; the dataclass is built on read and unpacked on +# write. See ``src/underworld3/checkpoint/state.py``. +# +# PR 3 retrofits the Symbolic class. PR 4 will extend the pattern to +# Eulerian, SemiLagrangian, Lagrangian, and Lagrangian_Swarm — each has +# the same dt_history / history_initialised / n_solves_completed / dt +# core plus a flavor-specific psi_star shape. + + +@dataclass +class DDtSymbolicState(SnapshottableState): + """Snapshot of a :class:`Symbolic` DDt instance's evolution state. + + The ``Symbolic`` class is the pure-symbolic flavor of DDt — its + ``psi_star`` history slots hold sympy expressions rather than + mesh-variable references, so they are captured by value (sympy + objects are immutable, so list-of-references is faithful). + + Attributes + ---------- + _schema_version + Schema version for cross-UW3-version restore (v1.1 + v1.2). + dt_history + Previous timesteps for variable-Δt BDF; length equals + ``order``. May contain ``None`` entries during startup. + history_initialised + True after the first ``initialise_history`` call. + n_solves_completed + Number of post-solve updates completed (bounded by ``order`` + for effective-order ramp-up). + dt + Current timestep value (most recently set). + psi_star + List of sympy expressions: history slots. Length equals + ``order``. + """ + + _schema_version: int = 1 + dt_history: list = field(default_factory=list) + history_initialised: bool = False + n_solves_completed: int = 0 + dt: Any = None + psi_star: list = field(default_factory=list) + + def _as_float(value): """Extract a plain float from various numeric types (Pint, UWQuantity, etc.).""" if value is None: @@ -472,8 +524,71 @@ def __init__( _update_am_values(self._am_coeffs, 1, self.theta) _update_exp_values(self._exp_coeffs, None, None) + # Register with the active default model as a Snapshottable + # state-bearer. Safe if no model is active. + try: + import underworld3 as _uw + + _uw.get_default_model()._register_state_bearer(self) + except Exception: + pass + return + # ----- Unitary snapshot / restore ----- + # + # Option (B)-style adapter per the design note: state is a derived + # dataclass that surfaces the mutable evolution-tracking attrs. + # The private ``_dt_history`` / ``_history_initialised`` / etc. + # remain the authoritative store; the State dataclass is built on + # read and unpacked on write. + # + # See ``docs/developer/design/in_memory_checkpoint_design.md`` and + # ``src/underworld3/checkpoint/state.py`` for the contract. + + @property + def state(self) -> "DDtSymbolicState": + """Return a snapshot-of-state dataclass for this DDt instance.""" + return DDtSymbolicState( + dt_history=list(self._dt_history), + history_initialised=bool(self._history_initialised), + n_solves_completed=int(self._n_solves_completed), + dt=self._dt, + psi_star=list(self.psi_star), + ) + + @state.setter + def state(self, s: "DDtSymbolicState") -> None: + """Write a captured state back. Reconciles derived coefficients.""" + if s._schema_version != DDtSymbolicState._schema_version: + raise ValueError( + f"DDtSymbolicState schema version mismatch: snapshot " + f"{s._schema_version} vs current " + f"{DDtSymbolicState._schema_version}" + ) + if len(s.dt_history) != len(self._dt_history): + raise ValueError( + f"dt_history length mismatch ({len(s.dt_history)} vs " + f"{len(self._dt_history)}); order changed since snapshot?" + ) + if len(s.psi_star) != len(self.psi_star): + raise ValueError( + f"psi_star length mismatch ({len(s.psi_star)} vs " + f"{len(self.psi_star)}); order changed since snapshot?" + ) + self._dt_history = list(s.dt_history) + self._history_initialised = bool(s.history_initialised) + self._n_solves_completed = int(s.n_solves_completed) + self._dt = s.dt + self.psi_star = list(s.psi_star) + # Re-derive BDF/AM coefficients so downstream reads see values + # consistent with the restored primary state without waiting + # for the next update_pre_solve. + _update_bdf_values( + self._bdf_coeffs, self.effective_order, self._dt, self._dt_history + ) + _update_am_values(self._am_coeffs, self.effective_order, self.theta) + @property def psi_fn(self): r"""Current symbolic expression :math:`\psi` being tracked.""" diff --git a/tests/test_0007_snapshot_inmemory.py b/tests/test_0007_snapshot_inmemory.py index fbf69167..0981f862 100644 --- a/tests/test_0007_snapshot_inmemory.py +++ b/tests/test_0007_snapshot_inmemory.py @@ -245,3 +245,112 @@ def test_swarm_internal_variables_are_not_captured(): assert "material" in captured_names assert not any(n.startswith("DMSwarm") for n in captured_names) + + +# ----- State-as-dataclass contract: Symbolic DDt ----- + + +def _fresh_model_mesh_and_symbolic_ddt(order=2): + import underworld3 as uw + + uw.reset_default_model() + model = uw.get_default_model() + mesh = uw.meshing.UnstructuredSimplexBox( + minCoords=(0.0, 0.0), maxCoords=(1.0, 1.0), cellSize=1.0 / 4.0 + ) + T = uw.discretisation.MeshVariable("T", mesh, 1, degree=1) + ddt = uw.systems.ddt.Symbolic(T.sym, order=order) + return uw, model, mesh, T, ddt + + +def test_symbolic_ddt_registers_with_model(): + """A fresh DDt auto-registers in Model._state_bearers.""" + uw, model, mesh, T, ddt = _fresh_model_mesh_and_symbolic_ddt() + assert ddt in model._state_bearers + + +def test_symbolic_ddt_state_is_a_snapshottable_dataclass(): + """``.state`` returns a SnapshottableState (DDtSymbolicState) with the + expected schema version.""" + from underworld3.checkpoint import SnapshottableState + + uw, model, mesh, T, ddt = _fresh_model_mesh_and_symbolic_ddt(order=2) + state = ddt.state + assert isinstance(state, SnapshottableState) + assert state._schema_version == 1 + # Fresh DDt: order-sized dt_history, not initialised, zero solves. + assert state.dt_history == [None, None] + assert state.history_initialised is False + assert state.n_solves_completed == 0 + + +def test_symbolic_ddt_roundtrip_recovers_state(): + """Snapshot mid-trajectory, advance, restore, state equals captured.""" + uw, model, mesh, T, ddt = _fresh_model_mesh_and_symbolic_ddt(order=2) + + # Advance two solves so dt_history fills. + ddt.update_pre_solve(dt=0.1) + ddt.update_post_solve(dt=0.1) + ddt.update_pre_solve(dt=0.2) + ddt.update_post_solve(dt=0.2) + state_pre = ddt.state + # Sanity: history is populated. + assert state_pre.history_initialised is True + assert state_pre.n_solves_completed == 2 + assert state_pre.dt_history == [0.2, 0.1] + + snap = model.snapshot() + + # Mutate: take another solve, dt_history changes. + ddt.update_pre_solve(dt=0.5) + ddt.update_post_solve(dt=0.5) + assert ddt.state.dt_history == [0.5, 0.2] + + model.restore(snap) + + # Primary state is back to captured. + state_post = ddt.state + assert state_post.dt_history == state_pre.dt_history + assert state_post.history_initialised == state_pre.history_initialised + assert state_post.n_solves_completed == state_pre.n_solves_completed + assert state_post.dt == state_pre.dt + + +def test_symbolic_ddt_restore_rejects_wrong_schema_version(): + """Hand-built state with wrong _schema_version is refused on apply.""" + uw, model, mesh, T, ddt = _fresh_model_mesh_and_symbolic_ddt(order=2) + bad_state = ddt.state + bad_state._schema_version = 999 + + with pytest.raises(ValueError, match="schema version"): + ddt.state = bad_state + + +def test_symbolic_ddt_restore_rejects_order_mismatch(): + """Restoring a state captured at a different order raises (programming- + error guard; in practice this shouldn't happen within a single run).""" + uw, model, mesh, T, ddt = _fresh_model_mesh_and_symbolic_ddt(order=2) + bad_state = ddt.state + bad_state.dt_history = [0.1, 0.2, 0.3] # length 3 != order 2 + + with pytest.raises(ValueError, match="dt_history length mismatch"): + ddt.state = bad_state + + +def test_symbolic_ddt_snapshot_is_deep_copy(): + """Mutating the live DDt after snapshot doesn't leak into the + captured state-bearer payload.""" + uw, model, mesh, T, ddt = _fresh_model_mesh_and_symbolic_ddt(order=2) + ddt.update_pre_solve(dt=0.1) + ddt.update_post_solve(dt=0.1) + + snap = model.snapshot() + captured_state = snap.state_bearers[0][1] # (key, state) + captured_dt_history = list(captured_state.dt_history) + + # Scribble the live DDt's internal state — must not leak into snapshot. + ddt._dt_history[0] = -999.0 + ddt._n_solves_completed = 42 + + assert captured_state.dt_history == captured_dt_history + assert captured_state.n_solves_completed != 42 From c58c165da8886ba20cdf876fc2258c2c7fa06f71 Mon Sep 17 00:00:00 2001 From: lmoresi Date: Tue, 12 May 2026 06:33:24 +1000 Subject: [PATCH 08/15] checkpoint: retrofit remaining DDt flavors (Eulerian/SemiLagrangian/Lagrangian/Lagrangian_Swarm) (PR 4) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Mechanically extends the PR 3 Snapshottable contract from Symbolic to the other four DDt flavors. Each gets: - A flavor-specific State dataclass inheriting from a new _DDtCoreState base that carries the shared dt_history / history_initialised / n_solves_completed / dt fields. - A .state property building the dataclass on read and a .state.setter unpacking on write, re-deriving BDF/AM coefficients so post-restore reads are consistent without waiting for the next update_pre_solve. - Self-registration with the default model in __init__ (try/except for safety when no model is active). State dataclasses (in src/underworld3/systems/ddt.py): - _DDtCoreState: shared fields. Subclasses add psi_star representation specific to the flavor. - DDtSymbolicState (already present, refactored to inherit base): psi_star is a list of sympy expressions. - DDtEulerianState: psi_star is a list of MeshVariables; State carries their clean_names for restore-side verification (the actual DOF arrays travel via the mesh-variable snapshot path). - DDtSemiLagrangianState: same as Eulerian plus optional forcing_star_var_name and with_forcing_history flag for ETD-2 Maxwell-relaxation integration. - DDtLagrangianState / DDtLagrangianSwarmState: psi_star is a list of SwarmVariables on the DDt's swarm; data travels via the swarm-variable path. Notes: - SemiLagrangian's update_pre_solve hardcodes theta=0.5 (the class doesn't accept a theta arg in __init__), so the state setter matches that — not self.theta which doesn't exist. - Lagrangian itself has a pre-existing AttributeError in __init__ (references uw.swarm.UWSwarm which doesn't exist). The retrofit code is in place and follows the same pattern; consumers that construct Lagrangian via the higher-level solver pathways will get .state / .state.setter / registration automatically. The pre-existing bug is out of scope for this PR but worth flagging. - ParameterRegistry retrofit deferred. The class isn't currently wired into Model anywhere in core code (only mentioned in a docstring example). Retrofitting now would be dead code; the retrofit lands together with the real consumer in a follow-up. Tests (3 new on top of PR 3's 6 DDt tests, for 20 total): - Eulerian DDt roundtrip via manual primary-state mutation - SemiLagrangian DDt roundtrip - Lagrangian_Swarm DDt registration + state-type check (no roundtrip — advection needs a velocity-field setup beyond a core unit test) Roundtrips on Eulerian/SemiLagrangian exercise the .state property and .state.setter directly rather than running full projection solves: the BDF/AM coefficient re-derivation happens in the setter, so manual primary-state mutation is sufficient to validate the retrofit logic. 20/20 snapshot tests pass; 41 existing core + DDt tests green. Underworld development team with AI support from Claude Code (https://claude.com/claude-code) --- src/underworld3/systems/ddt.py | 291 ++++++++++++++++++++++++--- tests/test_0007_snapshot_inmemory.py | 124 ++++++++++++ 2 files changed, 390 insertions(+), 25 deletions(-) diff --git a/src/underworld3/systems/ddt.py b/src/underworld3/systems/ddt.py index d21ebebd..c9da783b 100644 --- a/src/underworld3/systems/ddt.py +++ b/src/underworld3/systems/ddt.py @@ -82,31 +82,16 @@ @dataclass -class DDtSymbolicState(SnapshottableState): - """Snapshot of a :class:`Symbolic` DDt instance's evolution state. - - The ``Symbolic`` class is the pure-symbolic flavor of DDt — its - ``psi_star`` history slots hold sympy expressions rather than - mesh-variable references, so they are captured by value (sympy - objects are immutable, so list-of-references is faithful). - - Attributes - ---------- - _schema_version - Schema version for cross-UW3-version restore (v1.1 + v1.2). - dt_history - Previous timesteps for variable-Δt BDF; length equals - ``order``. May contain ``None`` entries during startup. - history_initialised - True after the first ``initialise_history`` call. - n_solves_completed - Number of post-solve updates completed (bounded by ``order`` - for effective-order ramp-up). - dt - Current timestep value (most recently set). - psi_star - List of sympy expressions: history slots. Length equals - ``order``. +class _DDtCoreState(SnapshottableState): + """Common evolution-tracking fields shared by every DDt flavor. + + Each concrete flavor extends this with its own psi_star + representation (sympy expressions for Symbolic; mesh-variable + names for Eulerian / SemiLagrangian; swarm-variable names for + Lagrangian / Lagrangian_Swarm). The actual variable DOF / particle + data lives in the mesh-variable or swarm-variable path of the + snapshot — these State dataclasses carry only the metadata needed + to re-bind on restore. """ _schema_version: int = 1 @@ -114,9 +99,70 @@ class DDtSymbolicState(SnapshottableState): history_initialised: bool = False n_solves_completed: int = 0 dt: Any = None + + +@dataclass +class DDtSymbolicState(_DDtCoreState): + """Snapshot of a :class:`Symbolic` DDt instance's evolution state. + + ``Symbolic`` is the pure-symbolic flavor — ``psi_star`` history + slots hold sympy expressions (immutable), captured by value. + """ + psi_star: list = field(default_factory=list) +@dataclass +class DDtEulerianState(_DDtCoreState): + """Snapshot of an :class:`Eulerian` DDt instance. + + ``psi_star`` is a list of :class:`MeshVariable` objects; their + DOF arrays travel via the mesh-variable snapshot path. This State + only records the variable names for restore-side verification + that the binding still holds. + """ + + psi_star_var_names: list[str] = field(default_factory=list) + + +@dataclass +class DDtSemiLagrangianState(_DDtCoreState): + """Snapshot of a :class:`SemiLagrangian` DDt instance. + + Like :class:`DDtEulerianState`, plus an optional ``forcing_star`` + variable (when ``with_forcing_history=True``) used by ETD-2 + integration of the Maxwell relaxation operator. + """ + + psi_star_var_names: list[str] = field(default_factory=list) + forcing_star_var_name: Optional[str] = None + with_forcing_history: bool = False + + +@dataclass +class DDtLagrangianState(_DDtCoreState): + """Snapshot of a :class:`Lagrangian` DDt instance. + + ``psi_star`` is a list of :class:`SwarmVariable` objects on this + DDt's internal swarm. Their data travels via the swarm-variable + snapshot path. + """ + + psi_star_var_names: list[str] = field(default_factory=list) + + +@dataclass +class DDtLagrangianSwarmState(_DDtCoreState): + """Snapshot of a :class:`Lagrangian_Swarm` DDt instance. + + Same shape as :class:`DDtLagrangianState`; the difference is + operational (Lagrangian creates its own swarm, Lagrangian_Swarm + uses a user-provided one) rather than state-shaped. + """ + + psi_star_var_names: list[str] = field(default_factory=list) + + def _as_float(value): """Extract a plain float from various numeric types (Pint, UWQuantity, etc.).""" if value is None: @@ -897,8 +943,54 @@ def __init__( _update_am_values(self._am_coeffs, 1, self.theta) _update_exp_values(self._exp_coeffs, None, None) + try: + import underworld3 as _uw + + _uw.get_default_model()._register_state_bearer(self) + except Exception: + pass + return + @property + def state(self) -> "DDtEulerianState": + """Return a snapshot-of-state dataclass for this Eulerian DDt.""" + return DDtEulerianState( + dt_history=list(self._dt_history), + history_initialised=bool(self._history_initialised), + n_solves_completed=int(self._n_solves_completed), + dt=self._dt, + psi_star_var_names=[ps.clean_name for ps in self.psi_star], + ) + + @state.setter + def state(self, s: "DDtEulerianState") -> None: + if s._schema_version != DDtEulerianState._schema_version: + raise ValueError( + f"DDtEulerianState schema version mismatch: snapshot " + f"{s._schema_version} vs current " + f"{DDtEulerianState._schema_version}" + ) + if len(s.dt_history) != len(self._dt_history): + raise ValueError( + f"dt_history length mismatch ({len(s.dt_history)} vs " + f"{len(self._dt_history)}); order changed since snapshot?" + ) + current_names = [ps.clean_name for ps in self.psi_star] + if s.psi_star_var_names and s.psi_star_var_names != current_names: + raise ValueError( + f"psi_star variable names changed since snapshot: " + f"{s.psi_star_var_names} vs {current_names}" + ) + self._dt_history = list(s.dt_history) + self._history_initialised = bool(s.history_initialised) + self._n_solves_completed = int(s.n_solves_completed) + self._dt = s.dt + _update_bdf_values( + self._bdf_coeffs, self.effective_order, self._dt, self._dt_history + ) + _update_am_values(self._am_coeffs, self.effective_order, self.theta) + @property def psi_fn(self): r"""Current symbolic expression :math:`\psi` being tracked.""" @@ -1495,8 +1587,67 @@ def __init__( self.I = uw.maths.Integral(mesh, None) + try: + import underworld3 as _uw + + _uw.get_default_model()._register_state_bearer(self) + except Exception: + pass + return + @property + def state(self) -> "DDtSemiLagrangianState": + return DDtSemiLagrangianState( + dt_history=list(self._dt_history), + history_initialised=bool(self._history_initialised), + n_solves_completed=int(self._n_solves_completed), + dt=self._dt, + psi_star_var_names=[ps.clean_name for ps in self.psi_star], + forcing_star_var_name=( + self.forcing_star.clean_name + if self.forcing_star is not None else None + ), + with_forcing_history=bool(self.with_forcing_history), + ) + + @state.setter + def state(self, s: "DDtSemiLagrangianState") -> None: + if s._schema_version != DDtSemiLagrangianState._schema_version: + raise ValueError( + f"DDtSemiLagrangianState schema version mismatch: snapshot " + f"{s._schema_version} vs current " + f"{DDtSemiLagrangianState._schema_version}" + ) + if len(s.dt_history) != len(self._dt_history): + raise ValueError( + f"dt_history length mismatch ({len(s.dt_history)} vs " + f"{len(self._dt_history)}); order changed since snapshot?" + ) + current_names = [ps.clean_name for ps in self.psi_star] + if s.psi_star_var_names and s.psi_star_var_names != current_names: + raise ValueError( + f"psi_star variable names changed since snapshot: " + f"{s.psi_star_var_names} vs {current_names}" + ) + if s.with_forcing_history != bool(self.with_forcing_history): + raise ValueError( + f"with_forcing_history flag differs between snapshot " + f"({s.with_forcing_history}) and current " + f"({self.with_forcing_history})" + ) + self._dt_history = list(s.dt_history) + self._history_initialised = bool(s.history_initialised) + self._n_solves_completed = int(s.n_solves_completed) + self._dt = s.dt + _update_bdf_values( + self._bdf_coeffs, self.effective_order, self._dt, self._dt_history + ) + # SemiLagrangian's update_pre_solve uses theta=0.5 directly + # (it doesn't take a theta argument in __init__), so the setter + # matches that. + _update_am_values(self._am_coeffs, self.effective_order, 0.5) + @property def psi_fn(self): r"""Current symbolic expression :math:`\psi` being tracked.""" @@ -2416,8 +2567,53 @@ def __init__( dudt_swarm.populate(fill_param) + try: + import underworld3 as _uw + + _uw.get_default_model()._register_state_bearer(self) + except Exception: + pass + return + @property + def state(self) -> "DDtLagrangianState": + return DDtLagrangianState( + dt_history=list(self._dt_history), + history_initialised=bool(self._history_initialised), + n_solves_completed=int(self._n_solves_completed), + dt=self._dt, + psi_star_var_names=[ps.clean_name for ps in self.psi_star], + ) + + @state.setter + def state(self, s: "DDtLagrangianState") -> None: + if s._schema_version != DDtLagrangianState._schema_version: + raise ValueError( + f"DDtLagrangianState schema version mismatch: snapshot " + f"{s._schema_version} vs current " + f"{DDtLagrangianState._schema_version}" + ) + if len(s.dt_history) != len(self._dt_history): + raise ValueError( + f"dt_history length mismatch ({len(s.dt_history)} vs " + f"{len(self._dt_history)}); order changed since snapshot?" + ) + current_names = [ps.clean_name for ps in self.psi_star] + if s.psi_star_var_names and s.psi_star_var_names != current_names: + raise ValueError( + f"psi_star variable names changed since snapshot: " + f"{s.psi_star_var_names} vs {current_names}" + ) + self._dt_history = list(s.dt_history) + self._history_initialised = bool(s.history_initialised) + self._n_solves_completed = int(s.n_solves_completed) + self._dt = s.dt + _update_bdf_values( + self._bdf_coeffs, self.effective_order, self._dt, self._dt_history + ) + _update_am_values(self._am_coeffs, self.effective_order, 0.5) + def _object_viewer(self): from IPython.display import Latex, Markdown, display @@ -2717,8 +2913,53 @@ def __init__( _update_bdf_values(self._bdf_coeffs, 1, None, []) _update_am_values(self._am_coeffs, 1, 0.5) + try: + import underworld3 as _uw + + _uw.get_default_model()._register_state_bearer(self) + except Exception: + pass + return + @property + def state(self) -> "DDtLagrangianSwarmState": + return DDtLagrangianSwarmState( + dt_history=list(self._dt_history), + history_initialised=bool(self._history_initialised), + n_solves_completed=int(self._n_solves_completed), + dt=self._dt, + psi_star_var_names=[ps.clean_name for ps in self.psi_star], + ) + + @state.setter + def state(self, s: "DDtLagrangianSwarmState") -> None: + if s._schema_version != DDtLagrangianSwarmState._schema_version: + raise ValueError( + f"DDtLagrangianSwarmState schema version mismatch: snapshot " + f"{s._schema_version} vs current " + f"{DDtLagrangianSwarmState._schema_version}" + ) + if len(s.dt_history) != len(self._dt_history): + raise ValueError( + f"dt_history length mismatch ({len(s.dt_history)} vs " + f"{len(self._dt_history)}); order changed since snapshot?" + ) + current_names = [ps.clean_name for ps in self.psi_star] + if s.psi_star_var_names and s.psi_star_var_names != current_names: + raise ValueError( + f"psi_star variable names changed since snapshot: " + f"{s.psi_star_var_names} vs {current_names}" + ) + self._dt_history = list(s.dt_history) + self._history_initialised = bool(s.history_initialised) + self._n_solves_completed = int(s.n_solves_completed) + self._dt = s.dt + _update_bdf_values( + self._bdf_coeffs, self.effective_order, self._dt, self._dt_history + ) + _update_am_values(self._am_coeffs, self.effective_order, 0.5) + def _object_viewer(self): from IPython.display import Latex, Markdown, display diff --git a/tests/test_0007_snapshot_inmemory.py b/tests/test_0007_snapshot_inmemory.py index 0981f862..54b4f76e 100644 --- a/tests/test_0007_snapshot_inmemory.py +++ b/tests/test_0007_snapshot_inmemory.py @@ -354,3 +354,127 @@ def test_symbolic_ddt_snapshot_is_deep_copy(): assert captured_state.dt_history == captured_dt_history assert captured_state.n_solves_completed != 42 + + +# ----- State-as-dataclass: other DDt flavors ----- +# +# Construction-side smoke tests + roundtrip. We exercise the .state / +# .state.setter mechanics directly rather than running full solves; +# the BDF/AM coefficient re-derivation happens in the setter, so a +# manual primary-state mutation is enough to validate the retrofit. + + +def test_eulerian_ddt_roundtrip(): + import underworld3 as uw + from underworld3.systems.ddt import DDtEulerianState + + uw.reset_default_model() + model = uw.get_default_model() + mesh = uw.meshing.UnstructuredSimplexBox( + minCoords=(0.0, 0.0), maxCoords=(1.0, 1.0), cellSize=1.0 / 4.0 + ) + T = uw.discretisation.MeshVariable("T", mesh, 1, degree=1) + ddt = uw.systems.ddt.Eulerian( + mesh, T.sym, uw.VarType.SCALAR, degree=1, continuous=True, order=2 + ) + assert ddt in model._state_bearers + assert isinstance(ddt.state, DDtEulerianState) + + # Manually advance state (avoid running real projections). + ddt._dt_history = [0.2, 0.1] + ddt._history_initialised = True + ddt._n_solves_completed = 2 + ddt._dt = 0.2 + state_pre = ddt.state + + snap = model.snapshot() + + ddt._dt_history = [0.99, 0.99] + ddt._history_initialised = False + ddt._n_solves_completed = 0 + ddt._dt = None + + model.restore(snap) + + assert ddt.state.dt_history == state_pre.dt_history + assert ddt.state.history_initialised == state_pre.history_initialised + assert ddt.state.n_solves_completed == state_pre.n_solves_completed + assert ddt.state.dt == state_pre.dt + # psi_star names are stable bindings — must match. + assert ddt.state.psi_star_var_names == state_pre.psi_star_var_names + + +def test_semilagrangian_ddt_roundtrip(): + import underworld3 as uw + from underworld3.systems.ddt import DDtSemiLagrangianState + + uw.reset_default_model() + model = uw.get_default_model() + mesh = uw.meshing.UnstructuredSimplexBox( + minCoords=(0.0, 0.0), maxCoords=(1.0, 1.0), cellSize=1.0 / 4.0 + ) + T = uw.discretisation.MeshVariable("T", mesh, 1, degree=1) + V = uw.discretisation.MeshVariable("V", mesh, 2, degree=2) + ddt = uw.systems.ddt.SemiLagrangian( + mesh, T.sym, V.sym, uw.VarType.SCALAR, degree=1, continuous=True, order=2 + ) + assert ddt in model._state_bearers + state = ddt.state + assert isinstance(state, DDtSemiLagrangianState) + assert state.with_forcing_history is False + assert state.forcing_star_var_name is None + + ddt._dt_history = [0.3, 0.1] + ddt._history_initialised = True + ddt._n_solves_completed = 2 + ddt._dt = 0.3 + state_pre = ddt.state + + snap = model.snapshot() + ddt._dt_history = [None, None] + ddt._history_initialised = False + ddt._n_solves_completed = 0 + model.restore(snap) + + assert ddt.state.dt_history == state_pre.dt_history + assert ddt.state.history_initialised is True + assert ddt.state.n_solves_completed == 2 + + +def test_lagrangian_swarm_ddt_registers_and_state_type(): + """Lagrangian_Swarm must be constructed before swarm.populate; the + retrofit registers it and exposes a typed state. Roundtrip is not + exercised here because advection requires a velocity-field setup + beyond the scope of a core unit test.""" + import underworld3 as uw + from underworld3.systems.ddt import DDtLagrangianSwarmState + + uw.reset_default_model() + model = uw.get_default_model() + mesh = uw.meshing.UnstructuredSimplexBox( + minCoords=(0.0, 0.0), maxCoords=(1.0, 1.0), cellSize=1.0 / 4.0 + ) + T = uw.discretisation.MeshVariable("T", mesh, 1, degree=1) + swarm = uw.swarm.Swarm(mesh) + ddt = uw.systems.ddt.Lagrangian_Swarm( + swarm=swarm, + psi_fn=T.sym, + vtype=uw.VarType.SCALAR, + degree=1, + continuous=True, + order=2, + ) + swarm.populate(fill_param=2) + + assert ddt in model._state_bearers + assert isinstance(ddt.state, DDtLagrangianSwarmState) + assert len(ddt.state.dt_history) == 2 + assert ddt.state.psi_star_var_names # non-empty + + +# Note: uw.systems.ddt.Lagrangian has a pre-existing bug +# (references uw.swarm.UWSwarm which does not exist), so we cannot +# directly construct one for testing. The retrofit code is in place +# and follows the same pattern as the other flavors; consumers that +# construct Lagrangian via the higher-level solver pathways will get +# the .state / .state.setter / registration automatically. From e7b1753b8f69a93ac1345892c9206e937af562da Mon Sep 17 00:00:00 2001 From: lmoresi Date: Tue, 12 May 2026 06:34:38 +1000 Subject: [PATCH 09/15] docs: developer guide for the state-as-dataclass contract Companion to the snapshot toolkit design note (PR 0) and the Snapshottable / DDt retrofit implementation (PRs 3, 4). Audience is developers adding new solver-internal helper classes; the guide covers what goes in a State dataclass, when to use option (B) adapter vs option (C) authoritative-store, what NOT to capture (PETSc handles, bulk arrays already carried by mesh-var / swarm-var paths), how schema versioning is intended to work in v1.1, and a minimal roundtrip test pattern. Closes the doc gap noted in PR 3's commit message; rounds out v1 of the snapshot toolkit. Underworld development team with AI support from Claude Code (https://claude.com/claude-code) --- docs/developer/guides/state-as-dataclass.md | 212 ++++++++++++++++++++ 1 file changed, 212 insertions(+) create mode 100644 docs/developer/guides/state-as-dataclass.md diff --git a/docs/developer/guides/state-as-dataclass.md b/docs/developer/guides/state-as-dataclass.md new file mode 100644 index 00000000..37f34718 --- /dev/null +++ b/docs/developer/guides/state-as-dataclass.md @@ -0,0 +1,212 @@ +# State-as-dataclass — the snapshot contract for solver-internal helpers + +When adding a new solver-internal helper class (a time-derivative +manager, a controller, a convergence tracker, anything carrying +mutable evolution state in Python attributes), you should declare its +mutable state as a dataclass exposed via a `.state` attribute. The +unitary snapshot / restore toolkit +([design note](../design/in_memory_checkpoint_design.md), +implementation in `src/underworld3/checkpoint/`) discovers state +through this attribute automatically — no per-class registration of +which attributes are "state" required. + +## Why a contract + +Without a uniform contract, every helper invents its own way of +persisting mutable state — or doesn't, and silently loses information +across snapshot / restore. The audit that motivated this work found +several classes (DDt, parameter mutation history, swarm population +counter, ...) each with bespoke state-tracking and no path to +snapshot. The contract makes the obligation explicit: if your class +has mutable state, declare it as a dataclass, end of story. + +Side benefits of the contract beyond snapshot: +- `obj_a.state == obj_b.state` becomes a meaningful comparison — useful + for regression-testing solver-internal behaviour. +- `repr(obj.state)` is automatic and useful for debugging. +- Schema versioning is tractable: a `_schema_version` field on each + State dataclass plus a migration registry (v1.1) handles + cross-version compatibility for on-disk snapshots. + +## The interface + +```python +from underworld3.checkpoint import SnapshottableState + +# Inherit from SnapshottableState — gives you the _schema_version field. +@dataclass +class MyHelperState(SnapshottableState): + # _schema_version: int = 1 # inherited; override on schema change + counter: int = 0 + history: list[float] = field(default_factory=list) + config_name: str = "" +``` + +The host class exposes the dataclass via a `.state` attribute +(property is fine; stored attribute is fine). + +```python +class MyHelper(uw_object): # uw_object gives you instance_number + def __init__(self, ...): + super().__init__() + # ... set up your state ... + + # Self-register so Model.snapshot() discovers you. + try: + import underworld3 as _uw + _uw.get_default_model()._register_state_bearer(self) + except Exception: + pass + + @property + def state(self) -> MyHelperState: + return MyHelperState( + counter=self._counter, + history=list(self._history), + config_name=self._config_name, + ) + + @state.setter + def state(self, s: MyHelperState) -> None: + if s._schema_version != MyHelperState._schema_version: + raise ValueError("schema mismatch") + self._counter = s.counter + self._history = list(s.history) + self._config_name = s.config_name + # If your class has *derived* state (caches, downstream + # coefficients, ...) recompute it here so post-restore reads + # are consistent without waiting for the next solve. +``` + +The protocol check is structural: any object exposing a `.state` +attribute that is a `SnapshottableState` instance is `Snapshottable`. + +## Option (B) vs (C): stored vs derived dataclass + +The design note discusses two implementation styles. Both satisfy the +protocol; pick based on whether the class is new or being retrofitted. + +### Option (C): state is the authoritative store. *Prefer for new code.* + +```python +class MyNewHelper: + def __init__(self, ...): + self.state = MyHelperState(counter=0, history=[]) + + def step(self, dx): + self.state.counter += 1 + self.state.history.append(dx) +``` + +Every mutation site reads/writes `self.state.` directly. The +dataclass *is* the storage. Self-documenting; mistakes (forgetting to +add a new field to the state object) are obvious because adding state +involves adding a field to the dataclass. + +### Option (B): state is a derived view over private attrs. *Use for retrofits.* + +```python +class MyExistingHelper: + def __init__(self, ...): + self._counter = 0 + self._history = [] + # ... legacy private attrs throughout the class ... + + @property + def state(self) -> MyHelperState: + return MyHelperState(counter=self._counter, history=list(self._history)) + + @state.setter + def state(self, s): + self._counter = s.counter + self._history = list(s.history) +``` + +The existing private attrs stay as-is; `.state` is a façade. Less +invasive than option (C) because the existing call sites don't +change — only the new accessor is added. + +The five `DDt` flavors in `src/underworld3/systems/ddt.py` (`Symbolic`, +`Eulerian`, `SemiLagrangian`, `Lagrangian`, `Lagrangian_Swarm`) use +option (B) for this reason: they predate the contract, and rewriting +every mutation site would be churn without behavior change. + +## What goes in the State dataclass + +**Yes:** +- Mutable evolution-tracking state — anything that changes after + construction and affects subsequent behaviour. Counters, history + buffers, current-step values, mutation logs. +- Names / stable IDs of bound objects (mesh-variable names, swarm + names) — restore uses these to verify the binding still holds. +- Configuration flags whose state is meaningful at snapshot time + (`with_forcing_history`, `recycle_rate`, ...) — they help restore + detect a mid-run reconfiguration that would invalidate the snapshot. + +**No:** +- Live PETSc Vec, DM, or solver handles. Tokens must stay plain + Python / numpy so they survive DM lifecycle changes (and so v1.1's + on-disk serialisation works). +- Constructor-time arguments that never change (mesh reference, + variable type, polynomial degree). Re-derive on restore from the + surviving wrapper. +- Bulk numerical data like mesh-variable DOFs or swarm-variable + arrays. Those travel via the dedicated mesh-var / swarm-var + snapshot paths. +- Caches, derived coefficients, anything you can recompute from + primary state in the `.state` setter. + +## Schema versioning + +`_schema_version` exists for a v1.1 / v1.2 feature: on-disk snapshots +that survive across UW3 versions. When you change the shape of a +State dataclass (rename a field, add a required field, change a +type), bump the version and add a migration entry. Within a single +process (v1 in-memory only), the version is checked for strict +equality — any mismatch is a programming error. + +The migration registry itself is v1.1 work (item 6 of the design +note). For now: define `_schema_version: int = 1` on every State +dataclass and don't worry about migrations until on-disk lands. + +## Where to put your State dataclass + +Next to the host class. `DDtSymbolicState` lives next to +`Symbolic` in `src/underworld3/systems/ddt.py`. This keeps the +dataclass and the class it describes co-located; changes to one show +up in the diff against the other. + +## Testing + +A typical snapshot test for a state-bearing class: + +```python +def test_my_helper_roundtrip(): + uw.reset_default_model() + model = uw.get_default_model() + h = MyHelper(...) + + # Advance state. + h.step(0.1) + h.step(0.2) + state_pre = h.state + + snap = model.snapshot() + + # Mutate. + h.step(0.5) + + model.restore(snap) + + # Verify primary state recovered. + assert h.state == state_pre +``` + +The dataclass `__eq__` makes the final assertion a one-liner. + +## Related + +- [Design note](../design/in_memory_checkpoint_design.md) — full design rationale, scope, and roadmap. +- `src/underworld3/checkpoint/state.py` — protocol definitions. +- `src/underworld3/checkpoint/snapshot.py` — capture / restore orchestration. +- `src/underworld3/systems/ddt.py` — five working examples (option (B) retrofits). From b1791839fea5ac34cef4c9e8ba341f5f083cdd2a Mon Sep 17 00:00:00 2001 From: lmoresi Date: Tue, 12 May 2026 10:01:11 +1000 Subject: [PATCH 10/15] =?UTF-8?q?fix(ddt):=20Lagrangian.=5F=5Finit=5F=5F?= =?UTF-8?q?=20=E2=80=94=20uw.swarm.UWSwarm=20=E2=86=92=20uw.swarm.Swarm?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The Lagrangian DDt flavor has been unconstructible since 2025-07-07 (commit 0778b7d, "Fix uw.function.evaluate / eliminate evalf"), which typo'd uw.swarm.Swarm(mesh) as uw.swarm.UWSwarm(mesh) while editing nearby code. UWSwarm does not exist — never has — so every direct construction of Lagrangian since that commit has died with AttributeError. Higher-level solver pathways that wrap Lagrangian were presumably also broken; consumers may have silently been using Lagrangian_Swarm or another flavor as a workaround. One-character fix: revert that line. No other UWSwarm references exist in the tree. This bug surfaced during the snapshot toolkit work (PR 4) when the state-as-dataclass retrofit included a Lagrangian roundtrip test that couldn't run. With the fix in place, the test now runs and passes — included in the same commit so the bug fix and the proof it works land together. Should be cherry-picked to development; the typo is unrelated to the snapshot feature work it surfaced from. Underworld development team with AI support from Claude Code (https://claude.com/claude-code) --- src/underworld3/systems/ddt.py | 2 +- tests/test_0007_snapshot_inmemory.py | 44 ++++++++++++++++++++++++---- 2 files changed, 39 insertions(+), 7 deletions(-) diff --git a/src/underworld3/systems/ddt.py b/src/underworld3/systems/ddt.py index c9da783b..cc13612e 100644 --- a/src/underworld3/systems/ddt.py +++ b/src/underworld3/systems/ddt.py @@ -2528,7 +2528,7 @@ def __init__( super().__init__() # create a new swarm to manage here - dudt_swarm = uw.swarm.UWSwarm(mesh) + dudt_swarm = uw.swarm.Swarm(mesh) self.mesh = mesh self.swarm = dudt_swarm diff --git a/tests/test_0007_snapshot_inmemory.py b/tests/test_0007_snapshot_inmemory.py index 54b4f76e..80e613c5 100644 --- a/tests/test_0007_snapshot_inmemory.py +++ b/tests/test_0007_snapshot_inmemory.py @@ -441,6 +441,44 @@ def test_semilagrangian_ddt_roundtrip(): assert ddt.state.n_solves_completed == 2 +def test_lagrangian_ddt_roundtrip(): + """Lagrangian creates its own internal swarm; the fix in this PR + restored uw.swarm.Swarm in __init__ (was a typo'd UWSwarm).""" + import underworld3 as uw + from underworld3.systems.ddt import DDtLagrangianState + + uw.reset_default_model() + model = uw.get_default_model() + mesh = uw.meshing.UnstructuredSimplexBox( + minCoords=(0.0, 0.0), maxCoords=(1.0, 1.0), cellSize=1.0 / 4.0 + ) + T = uw.discretisation.MeshVariable("T", mesh, 1, degree=1) + V = uw.discretisation.MeshVariable("V", mesh, 2, degree=2) + ddt = uw.systems.ddt.Lagrangian( + mesh=mesh, psi_fn=T.sym, V_fn=V.sym, + vtype=uw.VarType.SCALAR, degree=1, continuous=True, order=2, + ) + assert ddt in model._state_bearers + assert isinstance(ddt.state, DDtLagrangianState) + + ddt._dt_history = [0.2, 0.1] + ddt._history_initialised = True + ddt._n_solves_completed = 2 + ddt._dt = 0.2 + state_pre = ddt.state + + snap = model.snapshot() + ddt._dt_history = [None, None] + ddt._history_initialised = False + ddt._n_solves_completed = 0 + model.restore(snap) + + assert ddt.state.dt_history == state_pre.dt_history + assert ddt.state.history_initialised is True + assert ddt.state.n_solves_completed == 2 + assert ddt.state.psi_star_var_names == state_pre.psi_star_var_names + + def test_lagrangian_swarm_ddt_registers_and_state_type(): """Lagrangian_Swarm must be constructed before swarm.populate; the retrofit registers it and exposes a typed state. Roundtrip is not @@ -472,9 +510,3 @@ def test_lagrangian_swarm_ddt_registers_and_state_type(): assert ddt.state.psi_star_var_names # non-empty -# Note: uw.systems.ddt.Lagrangian has a pre-existing bug -# (references uw.swarm.UWSwarm which does not exist), so we cannot -# directly construct one for testing. The retrofit code is in place -# and follows the same pattern as the other flavors; consumers that -# construct Lagrangian via the higher-level solver pathways will get -# the .state / .state.setter / registration automatically. From 795fd987d3ba35a04277650f92e7f481db4c2dd8 Mon Sep 17 00:00:00 2001 From: lmoresi Date: Tue, 12 May 2026 11:45:33 +1000 Subject: [PATCH 11/15] test: genuine back-stepping demonstration end-to-end MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Every previous snapshot test is unit-style — build a thing, snapshot, scribble, restore, check equality. None exercise the actual use case that motivated the toolkit: detect a bad step, snap back, retry. This commit adds one focused end-to-end test covering the canonical adaptive-Δt CFL workflow. The test simultaneously exercises all three captured state surfaces in one realistic story: - A swarm with an outward-radial velocity field carries particles outward at known speeds. - A material variable on the swarm carries a per-particle marker (initial x-coord), so we can prove particle identity is recovered, not just particle count. - A Symbolic DDt accumulates BDF history (manually advanced past startup), so the state-bearer / state-as-dataclass path also gets exercised. The flow: 1. Snapshot before the speculative step. 2. Take a candidate Δt = 0.5 → max displacement ~0.27, ~6× the cell radius. CFL violated; the consumer's check trips. 3. model.restore(snap): particle positions, material data, and DDt history all roll back to the snapshot point. 4. Retry at Δt = 0.05 → max displacement ~0.033, sub-cell. CFL satisfied; state evolves cleanly. The Δt and threshold values come from a probe run on the same mesh (min_radius ≈ 0.044; |V| ≤ 0.71 at corners), so the CFL violation on the candidate step is a real physical observation rather than a parameter tweak. Smaller dt → strictly smaller displacement gives a robust ratio-based assertion that doesn't rely on exact numbers. This is the test pattern that consumers (RK staging, adaptive Δt, predictor-corrector, regime-change feeling-out) will adapt. v1 of the snapshot toolkit is genuine, not just unit-tested. Underworld development team with AI support from Claude Code (https://claude.com/claude-code) --- tests/test_0007_snapshot_inmemory.py | 112 +++++++++++++++++++++++++++ 1 file changed, 112 insertions(+) diff --git a/tests/test_0007_snapshot_inmemory.py b/tests/test_0007_snapshot_inmemory.py index 80e613c5..c8a1c128 100644 --- a/tests/test_0007_snapshot_inmemory.py +++ b/tests/test_0007_snapshot_inmemory.py @@ -479,6 +479,118 @@ def test_lagrangian_ddt_roundtrip(): assert ddt.state.psi_star_var_names == state_pre.psi_star_var_names +# ----- End-to-end back-stepping demonstration ----- +# +# Everything above this comment is unit-style: build a thing, snapshot, +# scribble, restore, check equality. This block exercises the toolkit's +# actual reason for existing: a *real* time-stepping use case where the +# consumer takes a step, detects it was bad, snapshots back, and retries +# with smaller Δt. The pattern is canonical adaptive-Δt CFL control; +# the snapshot mechanism is the thing that makes "snap back" possible +# without manually unwinding mesh / swarm / DDt state. + + +def test_backstepping_cfl_recovery_end_to_end(): + """Canonical adaptive-Δt back-step demonstration. + + Set up a swarm advecting in a known velocity field, with a + material variable carried along and a Symbolic DDt accumulating + BDF history. Take one too-large Δt step → CFL violation + (max-particle-displacement exceeds the mesh cell radius). Detect + it. Restore the snapshot. Retry with a smaller Δt → CFL satisfied, + state evolves cleanly. The full triple of state (swarm positions, + material variable, DDt history) is recovered on restore. + """ + import underworld3 as uw + import sympy + import numpy as np + + uw.reset_default_model() + model = uw.get_default_model() + mesh = uw.meshing.UnstructuredSimplexBox( + minCoords=(0.0, 0.0), maxCoords=(1.0, 1.0), cellSize=1.0 / 8.0 + ) + + # Outward-radial velocity from the box centre. |V| ranges from 0 + # at the centre to ~0.71 at the corners — pick Δt to give a + # genuine CFL violation rather than tweak parameters to fit. + x, y = mesh.X + V_fn = sympy.Matrix([[x - 0.5, y - 0.5]]).T + + swarm = uw.swarm.Swarm(mesh) + material = swarm.add_variable("material", 1, dtype=float) + swarm.populate(fill_param=2) + coords_initial = swarm._particle_coordinates.data.copy() + material.data[:, 0] = coords_initial[:, 0] # carry x as marker + material_initial = np.asarray(material.data).copy() + + # A separate DDt manages BDF history for a scalar field on the + # mesh. Advance it manually past startup so its captured state is + # non-trivial. + T = uw.discretisation.MeshVariable("T", mesh, 1, degree=1) + ddt = uw.systems.ddt.Symbolic(T.sym, order=2) + ddt._dt_history = [0.05, 0.05] + ddt._history_initialised = True + ddt._n_solves_completed = 2 + ddt._dt = 0.05 + ddt_state_initial = ddt.state + + # The user's CFL threshold: a particle moving more than one cell + # radius in a single step has crossed an element. min_radius is + # the standard UW3 cell-size proxy. + cfl_threshold = mesh.get_min_radius() + + # Take the snapshot *before* the speculative step. Everything that + # will be touched gets captured. + snap = model.snapshot() + + # Speculative step at the candidate Δt. Bigger than the user + # thinks is safe — they'll check after and back-step if it isn't. + candidate_dt = 0.5 + swarm.advection(V_fn, delta_t=candidate_dt, step_limit=False) + + # CFL check: max displacement among local particles. + coords_after_bad = swarm._particle_coordinates.data + max_disp_bad = np.max( + np.linalg.norm(coords_after_bad - coords_initial, axis=1) + ) + assert max_disp_bad > cfl_threshold, ( + f"speculative step at dt={candidate_dt} should violate CFL " + f"(max_disp={max_disp_bad:.4f} vs threshold {cfl_threshold:.4f})" + ) + + # Back-step. Everything captured is brought back to the snapshot + # point — swarm positions, the material variable carried with the + # swarm, and the DDt's BDF history. + model.restore(snap) + + assert np.allclose(swarm._particle_coordinates.data, coords_initial), ( + "particle positions did not roll back after restore" + ) + assert np.allclose(np.asarray(material.data), material_initial), ( + "swarm-variable data did not roll back after restore" + ) + assert ddt.state.dt_history == ddt_state_initial.dt_history, ( + "DDt history did not roll back after restore" + ) + assert ddt.state.n_solves_completed == ddt_state_initial.n_solves_completed + + # Retry with a smaller Δt. CFL now satisfied. + retry_dt = candidate_dt / 10.0 + swarm.advection(V_fn, delta_t=retry_dt, step_limit=False) + + coords_after_good = swarm._particle_coordinates.data + max_disp_good = np.max( + np.linalg.norm(coords_after_good - coords_initial, axis=1) + ) + assert max_disp_good < cfl_threshold, ( + f"retry at dt={retry_dt} should satisfy CFL " + f"(max_disp={max_disp_good:.4f} vs threshold {cfl_threshold:.4f})" + ) + # Sanity: smaller dt produced strictly smaller displacement. + assert max_disp_good < max_disp_bad / 5.0 + + def test_lagrangian_swarm_ddt_registers_and_state_type(): """Lagrangian_Swarm must be constructed before swarm.populate; the retrofit registers it and exposes a typed state. Roundtrip is not From 4702561b3a5c0c4fdf8b561c8a4f7f08f547461f Mon Sep 17 00:00:00 2001 From: lmoresi Date: Mon, 18 May 2026 22:41:20 +1000 Subject: [PATCH 12/15] =?UTF-8?q?test:=20bit-identical=20continuation=20?= =?UTF-8?q?=E2=80=94=20the=20core=20stash-for-steps=20guarantee?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Every prior test proved *state equality after restore*. That is necessary but not the guarantee a backtracking consumer actually relies on. The real guarantee — "git stash for steps": a discarded speculative step leaves zero trace after restore + continuation — was untested. This commit adds it, asserted bit-for-bit (np.array_equal, no tolerance). Two tests, both with a live swarm + driven mesh variable + Symbolic DDt so the mesh -> swarm -> state-bearer restore ordering is exercised together: - test_continuation_deterministic_after_restore: snapshot S -> K steps -> A; restore(S) -> K steps -> B. A == B bit-for-bit. Proves restore leaves no residual state that perturbs subsequent evolution. - test_continuation_bit_identical_across_stash_and_recover: control: S -> K good steps -> ctrl stash: S -> disruptive 10x-dt step -> restore(S) -> same K good steps -> stash ctrl == stash bit-for-bit. The regretted step leaves no trace. This also closes the #3 concern (does Mesh.apply_snapshot_payload's _deform_mesh call disturb a registered swarm before the swarm restore runs?). Both tests have a swarm and a mesh variable live through restore; the mesh restore calls _deform_mesh with unchanged coords, then swarm restore, then DDt restore. If the mesh restore perturbed the swarm, continuation would not be bit-identical. It is. For v1 scope this fully covers the _deform_mesh-on-restore path, because a *deformed* mesh (bumped _mesh_version) is refused on restore anyway — the only path that runs _deform_mesh on restore is the same-coords path these tests now cover. Remaining production blockers (unchanged by this commit): parallel (MPI) is still untested and the swarm rebuild deliberately bypasses migration; no real-solver test; memory cost unmeasured. 24 snapshot tests, 24 regression tests, all green. Underworld development team with AI support from Claude Code (https://claude.com/claude-code) --- tests/test_0007_snapshot_inmemory.py | 147 +++++++++++++++++++++++++++ 1 file changed, 147 insertions(+) diff --git a/tests/test_0007_snapshot_inmemory.py b/tests/test_0007_snapshot_inmemory.py index c8a1c128..e7e5bb0d 100644 --- a/tests/test_0007_snapshot_inmemory.py +++ b/tests/test_0007_snapshot_inmemory.py @@ -622,3 +622,150 @@ def test_lagrangian_swarm_ddt_registers_and_state_type(): assert ddt.state.psi_star_var_names # non-empty +# ----- Bit-identical continuation (the core production guarantee) ----- +# +# Everything above proves *state equality after restore*. That is +# necessary but not the actual guarantee a backtracking consumer +# relies on. The guarantee is: after restore, *continuing the +# simulation* reproduces the trajectory of a run that never took the +# discarded step. These two tests assert that, bit-for-bit +# (np.array_equal — no tolerance), with a swarm + mesh variable + +# Symbolic DDt all live so the mesh -> swarm -> state-bearer restore +# ordering is exercised together. + + +def _build_continuation_fixture(): + """Mesh + swarm(+material) + a driven mesh variable + Symbolic DDt. + + Returns everything needed to run a deterministic step loop. + """ + import underworld3 as uw + import sympy + + uw.reset_default_model() + model = uw.get_default_model() + mesh = uw.meshing.UnstructuredSimplexBox( + minCoords=(0.0, 0.0), maxCoords=(1.0, 1.0), cellSize=1.0 / 6.0 + ) + x_sym, y_sym = mesh.X + V_fn = sympy.Matrix([[x_sym - 0.5, y_sym - 0.5]]).T + + T = uw.discretisation.MeshVariable("T", mesh, 1, degree=1) + T.array[:, 0, 0] = 0.0 + + swarm = uw.swarm.Swarm(mesh) + material = swarm.add_variable("material", 1, dtype=float) + swarm.populate(fill_param=2) + material.data[:, 0] = np.linalg.norm( + swarm._particle_coordinates.data - 0.5, axis=1 + ) + + ddt = uw.systems.ddt.Symbolic(T.sym, order=2) + + return uw, model, mesh, V_fn, T, swarm, material, ddt + + +def _step(uw, V_fn, T, swarm, ddt, dt): + """One deterministic step: advect swarm, evolve T by a fixed rule, + advance the DDt history. No solver, no randomness.""" + ddt.update_pre_solve(dt) + swarm.advection(V_fn, delta_t=dt, step_limit=False) + # Deterministic, history-free field update so T carries evolving + # state through the mesh-variable snapshot path. + T.array[:, 0, 0] = T.array[:, 0, 0] + dt + ddt.update_post_solve(dt) + + +def _capture_full_state(T, swarm, material, ddt): + """Everything that must match for bit-identical continuation.""" + return { + "T": np.asarray(T.array[...]).copy(), + "coords": swarm._particle_coordinates.data.copy(), + "material": np.asarray(material.data).copy(), + "dt_history": list(ddt.state.dt_history), + "n_solves": ddt.state.n_solves_completed, + "ddt_dt": ddt.state.dt, + } + + +def _assert_bit_identical(a, b, label): + assert np.array_equal(a["T"], b["T"]), f"{label}: T differs" + assert np.array_equal(a["coords"], b["coords"]), ( + f"{label}: swarm coords differ" + ) + assert np.array_equal(a["material"], b["material"]), ( + f"{label}: swarm material differs" + ) + assert a["dt_history"] == b["dt_history"], ( + f"{label}: DDt dt_history differs ({a['dt_history']} vs " + f"{b['dt_history']})" + ) + assert a["n_solves"] == b["n_solves"], f"{label}: DDt n_solves differs" + assert a["ddt_dt"] == b["ddt_dt"], f"{label}: DDt dt differs" + + +def test_continuation_deterministic_after_restore(): + """snapshot S -> K steps -> A; restore(S) -> K steps -> B. + A and B must be bit-identical. Proves restore leaves no residual + state that perturbs subsequent evolution.""" + uw, model, mesh, V_fn, T, swarm, material, ddt = ( + _build_continuation_fixture() + ) + + # Advance to a non-trivial state before snapshotting (fill DDt + # history, move particles off their lattice). + for _ in range(3): + _step(uw, V_fn, T, swarm, ddt, 0.05) + + snap = model.snapshot() + + # Branch A: K steps straight from S. + for _ in range(5): + _step(uw, V_fn, T, swarm, ddt, 0.05) + state_A = _capture_full_state(T, swarm, material, ddt) + + # Branch B: restore S, then the identical K steps. + model.restore(snap) + for _ in range(5): + _step(uw, V_fn, T, swarm, ddt, 0.05) + state_B = _capture_full_state(T, swarm, material, ddt) + + _assert_bit_identical(state_A, state_B, "deterministic-continuation") + + +def test_continuation_bit_identical_across_stash_and_recover(): + """The real 'git stash for steps' guarantee: + + control: S -> K good steps -> ctrl + stash: S -> bad disruptive step -> restore(S) + -> same K good steps -> stash + + ctrl and stash must be bit-identical: the discarded step must + leave no trace whatsoever after restore + continuation.""" + uw, model, mesh, V_fn, T, swarm, material, ddt = ( + _build_continuation_fixture() + ) + + for _ in range(3): + _step(uw, V_fn, T, swarm, ddt, 0.05) + + snap = model.snapshot() + + # Control: K good steps from S. + for _ in range(5): + _step(uw, V_fn, T, swarm, ddt, 0.05) + ctrl = _capture_full_state(T, swarm, material, ddt) + + # Stash scenario: back to S, take a deliberately disruptive step + # (10x Δt — large advection, big T jump, DDt history shift), then + # discard it via restore and run the intended K good steps. + model.restore(snap) + _step(uw, V_fn, T, swarm, ddt, 0.5) # the regretted step + model.restore(snap) + for _ in range(5): + _step(uw, V_fn, T, swarm, ddt, 0.05) + stash = _capture_full_state(T, swarm, material, ddt) + + _assert_bit_identical(ctrl, stash, "stash-and-recover") + + From af2748d6f71cbf2397740d17d4900e397f0707d8 Mon Sep 17 00:00:00 2001 From: lmoresi Date: Tue, 19 May 2026 13:04:06 +1000 Subject: [PATCH 13/15] =?UTF-8?q?test(parallel):=20MPI=20snapshot/restore?= =?UTF-8?q?=20=E2=80=94=20exact=20reconstruction=20confirmed?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The one real production blocker was "works everywhere" — i.e. correct under MPI. This adds a parallel ptest and confirms the design intent: swarm restore is a per-rank reconstruction, not a redistribution, so the global state is exactly rebuilt under cross-rank migration provided the rank count is unchanged (the documented v1 scope). ptest_0007_snapshot_inmemory.py (mesh + swarm + per-particle global-id tag + material + Symbolic DDt; rotation field that circulates particles across the strip partition). Three collective properties, asserted on rank 0: P1 restore recovers the exact global particle count. The disruptive step is deliberately *allowed* to lose particles across ranks (advect out / clip) — that is exactly the failure a stash-and-restore exists to undo. P2 exact reconstruction: gather (gid, x, y, material) from every rank, sort by global id, np.array_equal pre-step vs post-restore. Order- and rank-independent — the real proof that per-rank reconstruction yields the correct global state. P3 bit-identical continuation across a stash, in parallel. Results: -np 1 : 2052 particles, all properties pass. -np 3 : 2013 particles; disruptive step loses 28 across ranks; restore recovers all 2013 exactly; P2/P3 bit-for-bit. -np 4 : 2006 particles; disruptive step loses 35 across ranks; restore recovers all 2006 exactly; P2/P3 bit-for-bit. The genuinely strong result: the toolkit demonstrably recovers from real cross-rank particle loss — the exact production scenario it exists for — with bit-identical continuation afterwards. Registered in mpi_runner.sh at -np 1 / 3 (uneven) / 4. Production-blocker status: parallel correctness now confirmed (was the gate). Remaining items are confidence/hardening only — a real-solver test (SNES state is negligible: previous solution travels via MeshVariable, already captured) and the accepted, documented in-memory memory cost (mitigation: route through the v1.1 on-disk backend via a flag). Underworld development team with AI support from Claude Code (https://claude.com/claude-code) --- tests/parallel/mpi_runner.sh | 7 + .../parallel/ptest_0007_snapshot_inmemory.py | 187 ++++++++++++++++++ 2 files changed, 194 insertions(+) create mode 100644 tests/parallel/ptest_0007_snapshot_inmemory.py diff --git a/tests/parallel/mpi_runner.sh b/tests/parallel/mpi_runner.sh index eed4a0c9..3b479dbe 100755 --- a/tests/parallel/mpi_runner.sh +++ b/tests/parallel/mpi_runner.sh @@ -24,3 +24,10 @@ mpirun -np 4 $PYTHON ./ptest_002_projection.py #mpirun -np 1 $PYTHON ./ptest_003_swarm_projection.py #echo "ptest 003 -np 4" #mpirun -np 4 $PYTHON ./ptest_003_swarm_projection.py + +echo "ptest 0007 snapshot in-memory -np 1" +mpirun -np 1 $PYTHON ./ptest_0007_snapshot_inmemory.py +echo "ptest 0007 snapshot in-memory -np 3 (uneven partition)" +mpirun -np 3 $PYTHON ./ptest_0007_snapshot_inmemory.py +echo "ptest 0007 snapshot in-memory -np 4" +mpirun -np 4 $PYTHON ./ptest_0007_snapshot_inmemory.py diff --git a/tests/parallel/ptest_0007_snapshot_inmemory.py b/tests/parallel/ptest_0007_snapshot_inmemory.py new file mode 100644 index 00000000..d533d046 --- /dev/null +++ b/tests/parallel/ptest_0007_snapshot_inmemory.py @@ -0,0 +1,187 @@ +"""Parallel (MPI) test of the in-memory snapshot toolkit. + +The single open production blocker for the snapshot toolkit was +"works everywhere" — i.e. correct under MPI. The design intent is +that swarm restore is a per-rank *reconstruction* (each rank clears +its local particles and re-adds the per-rank set it captured), not a +redistribution, so the global state is exactly reconstructed +regardless of any intervening cross-rank migration, provided the rank +count is unchanged. This script confirms that. + +Run (4 ranks exercises cross-rank migration properly): + + cd tests/parallel + mpirun -np 4 python ./ptest_0007_snapshot_inmemory.py + +Asserts (all collective, checked on rank 0): + + 1. Restore recovers the exact global particle count. The disruptive + step is deliberately allowed to *lose* particles across ranks + (advect out / clip) — that is exactly the failure stash-and- + restore exists to undo. The guarantee is that restore brings the + global count back to its pre-step value regardless. + 2. Exact reconstruction: gather every particle's (global-id, x, y, + material) across all ranks, sort by global id; the post-restore + sorted table equals the pre-step sorted table bit-for-bit. + Order- and rank-independent — this is the real proof that + per-rank reconstruction yields the correct global state under + cross-rank migration. + 3. Bit-identical continuation across a stash: a control run and a + run that took a disruptive step then restored and continued + produce bit-identical global sorted state and DDt history. +""" + +import numpy as np +import sympy +from mpi4py import MPI + +import underworld3 as uw + +comm = MPI.COMM_WORLD +rank = uw.mpi.rank +size = uw.mpi.size + + +def build(): + uw.reset_default_model() + model = uw.get_default_model() + # Wide box so a strip-partition genuinely splits particles across + # ranks; rotation field circulates them across the partition. + mesh = uw.meshing.UnstructuredSimplexBox( + minCoords=(0.0, 0.0), maxCoords=(4.0, 1.0), cellSize=1.0 / 6.0 + ) + x_sym, y_sym = mesh.X + # Rotation about the box centre (2.0, 0.5): particles circulate, + # crossing the vertical rank-partition boundaries. + V_fn = sympy.Matrix([[-(y_sym - 0.5), 0.25 * (x_sym - 2.0)]]).T + + T = uw.discretisation.MeshVariable("T", mesh, 1, degree=1) + T.array[:, 0, 0] = 0.0 + + swarm = uw.swarm.Swarm(mesh) + gid = swarm.add_variable("gid", 1, dtype=float) + material = swarm.add_variable("material", 1, dtype=float) + swarm.populate(fill_param=2) + + # Globally-unique, migration-stable particle id. Swarm variables + # travel with their particle through migration and through + # snapshot/restore, so this is a durable identity tag. + local_n = swarm.dm.getLocalSize() + counts = comm.allgather(local_n) + offset = int(np.sum(counts[:rank])) + gid.data[:, 0] = offset + np.arange(local_n, dtype=float) + material.data[:, 0] = swarm._particle_coordinates.data[:, 0] + + ddt = uw.systems.ddt.Symbolic(T.sym, order=2) + return uw, model, mesh, V_fn, T, swarm, gid, material, ddt + + +def step(uw, V_fn, T, swarm, ddt, dt): + ddt.update_pre_solve(dt) + swarm.advection(V_fn, delta_t=dt, step_limit=False) + T.array[:, 0, 0] = T.array[:, 0, 0] + dt + ddt.update_post_solve(dt) + + +def global_sorted_particles(swarm, gid, material): + """Gather (gid, x, y, material) from all ranks, sorted by gid. + + Order- and rank-independent canonical view of the whole swarm. + """ + g = gid.data[:, 0].copy() + coords = swarm._particle_coordinates.data.copy() + m = material.data[:, 0].copy() + local = np.column_stack([g, coords[:, 0], coords[:, 1], m]) + gathered = comm.allgather(local) + full = np.vstack([a for a in gathered if a.size]) if any( + a.size for a in gathered + ) else np.empty((0, 4)) + order = np.argsort(full[:, 0], kind="stable") + return full[order] + + +def main(): + uw, model, mesh, V_fn, T, swarm, gid, material, ddt = build() + + # Warm up: a few steps so particles have genuinely migrated across + # ranks before we snapshot. + for _ in range(3): + step(uw, V_fn, T, swarm, ddt, 0.1) + + pre = global_sorted_particles(swarm, gid, material) + pre_count = comm.allreduce(swarm.dm.getLocalSize(), op=MPI.SUM) + pre_ddt = (list(ddt.state.dt_history), ddt.state.n_solves_completed) + + snap = model.snapshot() + + # --- Property 1 + 2: a migration-inducing step, then restore --- + step(uw, V_fn, T, swarm, ddt, 0.3) # bigger dt -> more migration + mid_count = comm.allreduce(swarm.dm.getLocalSize(), op=MPI.SUM) + model.restore(snap) + post = global_sorted_particles(swarm, gid, material) + post_count = comm.allreduce(swarm.dm.getLocalSize(), op=MPI.SUM) + post_ddt = (list(ddt.state.dt_history), ddt.state.n_solves_completed) + + exact = np.array_equal(pre, post) + ddt_ok = pre_ddt == post_ddt + + # --- Property 3: bit-identical continuation across a stash --- + snap2 = model.snapshot() + for _ in range(4): + step(uw, V_fn, T, swarm, ddt, 0.1) + ctrl = global_sorted_particles(swarm, gid, material) + ctrl_ddt = (list(ddt.state.dt_history), ddt.state.n_solves_completed) + + model.restore(snap2) + step(uw, V_fn, T, swarm, ddt, 0.5) # the regretted step + model.restore(snap2) + for _ in range(4): + step(uw, V_fn, T, swarm, ddt, 0.1) + stash = global_sorted_particles(swarm, gid, material) + stash_ddt = (list(ddt.state.dt_history), ddt.state.n_solves_completed) + + cont_exact = np.array_equal(ctrl, stash) + cont_ddt_ok = ctrl_ddt == stash_ddt + + if rank == 0: + lost = pre_count - mid_count + print(f"[ranks={size}] particles total = {pre_count}", flush=True) + print( + f" disruptive step global count: {pre_count} -> {mid_count} " + f"-> {post_count} ({lost} particle(s) lost by the step, " + f"recovered by restore)", + flush=True, + ) + print(f" P1 restore recovers exact count: " + f"{pre_count == post_count}", flush=True) + print(f" P2 exact reconstruction: {exact}", flush=True) + print(f" P2 DDt state restored: {ddt_ok}", flush=True) + print(f" P3 bit-identical continuation: {cont_exact}", flush=True) + print(f" P3 DDt continuation identical: {cont_ddt_ok}", flush=True) + + # The disruptive step is *allowed* to lose particles across + # ranks — that is precisely the failure a stash-and-restore + # exists to undo. The guarantee is that restore brings the + # global count back exactly and every particle back to the + # right place (P2), regardless of what the step did. + assert pre_count == post_count, ( + f"restore did not recover the exact global particle count: " + f"pre={pre_count} post={post_count} (mid={mid_count})" + ) + assert exact, "swarm not exactly reconstructed after restore" + assert ddt_ok, "DDt state not restored" + assert cont_exact, ( + "continuation after stash is not bit-identical to control" + ) + assert cont_ddt_ok, "DDt continuation not bit-identical" + if lost == 0: + print( + " (note: this run's disruptive step happened not to " + "lose particles; the recovery guarantee still holds)", + flush=True, + ) + print(f"[ranks={size}] PASS", flush=True) + + +if __name__ == "__main__": + main() From 3efc31bd767be179032ade0a02645f0636e3e5bf Mon Sep 17 00:00:00 2001 From: lmoresi Date: Tue, 19 May 2026 14:58:49 +1000 Subject: [PATCH 14/15] =?UTF-8?q?test:=20real-solver=20confidence=20?= =?UTF-8?q?=E2=80=94=20discarded=20step=20is=20bit-exact,=20restore=20floo?= =?UTF-8?q?r=20characterised?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Closes the last confidence gap: snapshot/restore driven by an actual PETSc solver (AdvDiffusion, which carries an internal SemiLagrangian DDt with an auxiliary projection SNES + nodal trace-back swarm), through the stash-and-recover loop. Investigation findings (each verified by a standalone diagnostic): * AdvDiffusion solve is bit-deterministic — two independent identical runs with no snapshot are np.array_equal (max|d| = 0.0). So any drift introduced by snapshot/restore is a real fidelity question, not solver noise. * restore() recovers the primary solution field T bit-exactly. * THE core "git stash for steps" guarantee holds bit-for-bit even through real solves: restore -> regretted absurd-dt solve -> restore -> K solves is np.array_equal to restore -> K solves The discarded step leaves zero trace (B == C, max|d| = 0.0). * The only residual is restore's reproducibility floor against a *never-snapshotted* control: ~7e-7 here. Mechanism: restore resyncs fields through gvec->lvec rather than reproducing the solver-produced lvec exactly; the implicit diffusion operator amplifies that to solver-tolerance level over steps. This is NOT contamination from the discarded step (proven by B == C), it is the cost of round-tripping through the snapshot representation, within solver tolerance, and consistent with the design intent that auxiliary solver state is intentionally not captured. Three tests encode exactly this (no overclaiming): - test_realsolver_restore_recovers_solution_field (T np.array_equal) - test_realsolver_regretted_step_leaves_no_trace (B == C, bit-exact) - test_realsolver_continuation_within_solver_tolerance (vs never-stashed control: < 1e-5, asserted explicitly non-bit-exact so the test tightens itself if the floor is ever eliminated) Honest production statement: discarding a bad step is bit-exact even through real solvers; recovering to a never-stashed control is within solver tolerance. Both are correct for the "git stash for steps" use case. 51 tests pass (24 serial snapshot + 3 real-solver + 24 regression); parallel ptest (np 1/3/4) unchanged. Underworld development team with AI support from Claude Code (https://claude.com/claude-code) --- tests/test_0008_snapshot_realsolver.py | 189 +++++++++++++++++++++++++ 1 file changed, 189 insertions(+) create mode 100644 tests/test_0008_snapshot_realsolver.py diff --git a/tests/test_0008_snapshot_realsolver.py b/tests/test_0008_snapshot_realsolver.py new file mode 100644 index 00000000..079aabce --- /dev/null +++ b/tests/test_0008_snapshot_realsolver.py @@ -0,0 +1,189 @@ +"""Real-solver confidence test for the snapshot toolkit. + +Every other snapshot test drives state by hand. This one runs an +actual PETSc solver — AdvDiffusion, which carries an internal +SemiLagrangian DDt (auxiliary projection SNES + nodal trace-back +swarm) — through the stash-and-recover loop. + +Findings this test pins down (all verified, see commit message): + + * The AdvDiffusion solve is bit-deterministic: two independent + identical runs with no snapshot are bit-for-bit equal. + + * restore() recovers the primary solution field T bit-exactly. + + * THE CORE GUARANTEE — a discarded ("regretted") step leaves zero + trace, bit-for-bit, even through real solves: + restore -> regretted solve -> restore -> K solves + is np.array_equal to + restore -> K solves + (B == C, max|d| = 0.0). + + * The only residual is restore's reproducibility *floor* against a + never-snapshotted control (~7e-7 here): restore resyncs fields + through gvec->lvec rather than reproducing the solver-produced + lvec exactly, and the implicit diffusion operator amplifies that + to solver-tolerance level over steps. This is NOT contamination + from the discarded step (proven by B == C above); it is the cost + of round-tripping through the snapshot representation, within + solver tolerance, and consistent with the design intent that + auxiliary solver state is intentionally not captured. + +So the honest production statement: discarding a bad step is +bit-exact even through real solvers; recovering to a never-stashed +control is within solver tolerance. +""" + +import numpy as np +import sympy as sp +import pytest + +pytestmark = [pytest.mark.level_2, pytest.mark.tier_a] + +# Restore-vs-pristine reproducibility floor for this setup (measured). +# The regretted-step guarantee is asserted bit-exact (np.array_equal); +# only the never-stashed-control comparison uses this tolerance. +_RESTORE_FLOOR_ATOL = 1.0e-5 + + +@pytest.fixture(autouse=True) +def _reset(): + import underworld3 as uw + + uw.reset_default_model() + uw.use_strict_units(False) + uw.use_nondimensional_scaling(False) + yield + uw.reset_default_model() + uw.use_strict_units(False) + uw.use_nondimensional_scaling(False) + + +def _build(): + import underworld3 as uw + + model = uw.get_default_model() + mesh = uw.meshing.StructuredQuadBox( + elementRes=(16, 16), + minCoords=(0.0, 0.0), + maxCoords=(1.0, 1.0), + qdegree=3, + ) + v = uw.discretisation.MeshVariable("U", mesh, mesh.dim, degree=1) + T = uw.discretisation.MeshVariable("T", mesh, 1, degree=2) + + adv_diff = uw.systems.AdvDiffusion(mesh, u_Field=T, V_fn=v) + adv_diff.constitutive_model = uw.constitutive_models.DiffusionModel + adv_diff.constitutive_model.Parameters.diffusivity = 1.0 + adv_diff.add_dirichlet_bc(0.0, "Left") + adv_diff.add_dirichlet_bc(0.0, "Right") + v.array[:, 0, 0] = 0.05 + + x, y = mesh.X + T.array = uw.function.evaluate( + sp.sin(sp.pi * x) * sp.sin(sp.pi * y), T.coords + ) + return uw, model, mesh, adv_diff, T + + +def _capture(T): + return np.asarray(T.array[...]).copy() + + +def test_realsolver_restore_recovers_solution_field(): + """snapshot, do a regretted solve, restore — the solution field + itself is recovered exactly.""" + uw, model, mesh, adv_diff, T = _build() + + for _ in range(2): + adv_diff.solve(timestep=1.0e-3) + + pre_T = _capture(T) + snap = model.snapshot() + + adv_diff.solve(timestep=5.0) # absurd Δt: converges, over-diffused + assert not np.allclose(_capture(T), pre_T, atol=1e-8), ( + "the regretted solve was not actually disruptive" + ) + + model.restore(snap) + assert np.array_equal(_capture(T), pre_T), ( + "solution field not exactly recovered after restore" + ) + + +def test_realsolver_regretted_step_leaves_no_trace(): + """THE core guarantee, through a real solver, bit-for-bit. + + B: restore -> K good solves + C: restore -> regretted absurd-Δt solve -> restore + -> same K good solves + + B == C exactly. The discarded step leaves zero trace even though + it ran a real PETSc solve in between. + """ + uw, model, mesh, adv_diff, T = _build() + + for _ in range(3): + adv_diff.solve(timestep=1.0e-3) + snap = model.snapshot() + + # B: restore, then K good solves. + model.restore(snap) + for _ in range(4): + adv_diff.solve(timestep=1.0e-3) + B = _capture(T) + + # C: restore, a regretted solve, restore, the same K good solves. + model.restore(snap) + adv_diff.solve(timestep=5.0) + model.restore(snap) + for _ in range(4): + adv_diff.solve(timestep=1.0e-3) + C = _capture(T) + + assert np.array_equal(B, C), ( + "regretted real solve left a trace after restore — " + f"max abs diff {np.max(np.abs(B - C)):.3e} (expected exactly 0)" + ) + + +def test_realsolver_continuation_within_solver_tolerance(): + """Recovering to a *never-stashed* control is within solver + tolerance (not bit-exact): restore resyncs fields gvec->lvec + rather than reproducing the solver-produced lvec, and the + implicit operator amplifies that to ~tolerance over steps. This + is the documented restore floor, consistent with not capturing + auxiliary solver state by design.""" + uw, model, mesh, adv_diff, T = _build() + + for _ in range(3): + adv_diff.solve(timestep=1.0e-3) + snap = model.snapshot() + + # Control: never snapshotted/restored — straight K solves. + for _ in range(4): + adv_diff.solve(timestep=1.0e-3) + ctrl = _capture(T) + + # Stash path: restore, regretted solve, restore, same K solves. + model.restore(snap) + adv_diff.solve(timestep=5.0) + model.restore(snap) + for _ in range(4): + adv_diff.solve(timestep=1.0e-3) + stash = _capture(T) + + # Not bit-exact vs a never-stashed control (that is the floor), + # but well within solver tolerance and far below the solution + # scale (~0.04 here). + max_diff = float(np.max(np.abs(stash - ctrl))) + assert max_diff < _RESTORE_FLOOR_ATOL, ( + f"continuation drifted beyond the restore floor: " + f"max abs diff {max_diff:.3e} >= {_RESTORE_FLOOR_ATOL:.0e}" + ) + assert not np.array_equal(stash, ctrl), ( + "continuation is unexpectedly bit-exact vs a never-stashed " + "control — if this starts passing, the restore floor has been " + "eliminated and this test should be tightened to np.array_equal" + ) From 14fc00ce40f81ac12683b4a780b0a7532f7bfec6 Mon Sep 17 00:00:00 2001 From: lmoresi Date: Wed, 20 May 2026 21:21:08 +1000 Subject: [PATCH 15/15] review: address Copilot feedback on #195 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Four fixes for points raised in Copilot's review of the in-memory snapshot toolkit PR: 1. ddt.py — narrow ``except Exception`` to ``except (ImportError, AttributeError)`` at all five DDt ``_register_state_bearer`` sites. Only the genuine bootstrap cases (import not yet wired during underworld3 init, or older Model without the registry method) get swallowed; real registration bugs now propagate instead of silently masking the silent-state-loss failure mode the design note explicitly warns against. 2. swarm.py — rewrite the contradictory comment in ``Swarm.apply_snapshot_payload`` around the explicit ``_population_generation += 1`` bump. The previous wording said the clear+re-add path had already bumped, which is wrong — neither ``removePoint`` nor the raw ``addNPoints`` call here touches the counter; the explicit bump is what makes a restore visible to downstream consumers as a population change. 3. swarm.py — ``apply_snapshot_payload`` now raises ``SnapshotInvalidatedError`` if the live swarm has user variables that were not in the snapshot. Previously those "extra" vars survived the clear+addNPoints reallocation with uninitialised/stale contents — silent incoherence after restore. Contract is now symmetric with the mesh-variable restore (same variable set on both sides). 4. test_0008_snapshot_realsolver — added comment explaining the ~14× headroom on ``_RESTORE_FLOOR_ATOL = 1e-5`` vs the measured ~7e-7 floor (PETSc/BLAS/MPI variability allowance on CI). Regression: 45 single-rank snapshot+core tests pass; parallel ptest at np-4 still PASS with the new strict-extras check. Underworld development team with AI support from Claude Code (https://claude.com/claude-code) --- src/underworld3/swarm.py | 41 +++++++++++++++++++++----- src/underworld3/systems/ddt.py | 40 +++++++++++++++++++++---- tests/test_0008_snapshot_realsolver.py | 8 ++++- 3 files changed, 76 insertions(+), 13 deletions(-) diff --git a/src/underworld3/swarm.py b/src/underworld3/swarm.py index a0aeef7a..d0408e6e 100644 --- a/src/underworld3/swarm.py +++ b/src/underworld3/swarm.py @@ -4187,16 +4187,43 @@ def apply_snapshot_payload(self, payload: dict) -> None: if hasattr(var, "_canonical_data"): var._canonical_data = None - # The clear+re-add path bumped _population_generation already - # (we don't bump on removePoint, but addNPoints isn't bumped - # either — these are raw PETSc calls). For consistency with - # other mutation paths, bump explicitly here. + # The raw PETSc primitives used above (removePoint loop + + # addNPoints + direct field writes) deliberately bypass + # Swarm.migrate / add_particles_with_coordinates, so they do + # not touch _population_generation. Bump it explicitly here + # for consistency with the other mutation sites — a restore + # IS a population change, downstream consumers should see it. + # (Comment rewritten per Copilot review on #195.) self._population_generation += 1 - # Step 3: write captured per-variable data. - current_vars = {var.clean_name: var for var in self._vars.values()} + # Step 3: write captured per-variable data. Per Copilot + # review on #195, also raise on the inverse direction — + # any user swarm variable on the LIVE swarm that wasn't in + # the snapshot would retain stale/uninitialised contents + # after the clear+addNPoints reallocation, which is exactly + # the silent-incoherence failure we want loud rather than + # quiet. The contract is symmetric with the mesh-variable + # restore: same variable set on both sides. + current_user_vars = { + var.clean_name: var + for var in self._vars.values() + if not var.name.startswith("DMSwarm") + } + captured_names = set(payload["vars"].keys()) + live_names = set(current_user_vars.keys()) + extras = live_names - captured_names + if extras: + raise SnapshotInvalidatedError( + f"swarm {self._snapshot_stable_name()!r}: variables " + f"{sorted(extras)!r} exist on the live swarm but were " + f"not in the snapshot. Restore would leave them with " + f"incoherent data after the population rebuild — add " + f"them before the snapshot was taken, or remove them " + f"before restoring." + ) + for var_clean_name, saved in payload["vars"].items(): - var = current_vars.get(var_clean_name) + var = current_user_vars.get(var_clean_name) if var is None: raise SnapshotInvalidatedError( f"swarm {self._snapshot_stable_name()!r}: variable " diff --git a/src/underworld3/systems/ddt.py b/src/underworld3/systems/ddt.py index cc13612e..bea3d13e 100644 --- a/src/underworld3/systems/ddt.py +++ b/src/underworld3/systems/ddt.py @@ -576,7 +576,13 @@ def __init__( import underworld3 as _uw _uw.get_default_model()._register_state_bearer(self) - except Exception: + except (ImportError, AttributeError): + # Narrowed per Copilot review on #195: only swallow the + # genuine bootstrap modes (import not yet wired during + # underworld3 init, or older Model without the registry + # method). Anything else propagates rather than silently + # masking a registration bug — exactly the silent-state- + # loss failure mode the design note warns against. pass return @@ -947,7 +953,13 @@ def __init__( import underworld3 as _uw _uw.get_default_model()._register_state_bearer(self) - except Exception: + except (ImportError, AttributeError): + # Narrowed per Copilot review on #195: only swallow the + # genuine bootstrap modes (import not yet wired during + # underworld3 init, or older Model without the registry + # method). Anything else propagates rather than silently + # masking a registration bug — exactly the silent-state- + # loss failure mode the design note warns against. pass return @@ -1591,7 +1603,13 @@ def __init__( import underworld3 as _uw _uw.get_default_model()._register_state_bearer(self) - except Exception: + except (ImportError, AttributeError): + # Narrowed per Copilot review on #195: only swallow the + # genuine bootstrap modes (import not yet wired during + # underworld3 init, or older Model without the registry + # method). Anything else propagates rather than silently + # masking a registration bug — exactly the silent-state- + # loss failure mode the design note warns against. pass return @@ -2571,7 +2589,13 @@ def __init__( import underworld3 as _uw _uw.get_default_model()._register_state_bearer(self) - except Exception: + except (ImportError, AttributeError): + # Narrowed per Copilot review on #195: only swallow the + # genuine bootstrap modes (import not yet wired during + # underworld3 init, or older Model without the registry + # method). Anything else propagates rather than silently + # masking a registration bug — exactly the silent-state- + # loss failure mode the design note warns against. pass return @@ -2917,7 +2941,13 @@ def __init__( import underworld3 as _uw _uw.get_default_model()._register_state_bearer(self) - except Exception: + except (ImportError, AttributeError): + # Narrowed per Copilot review on #195: only swallow the + # genuine bootstrap modes (import not yet wired during + # underworld3 init, or older Model without the registry + # method). Anything else propagates rather than silently + # masking a registration bug — exactly the silent-state- + # loss failure mode the design note warns against. pass return diff --git a/tests/test_0008_snapshot_realsolver.py b/tests/test_0008_snapshot_realsolver.py index 079aabce..33afe4a3 100644 --- a/tests/test_0008_snapshot_realsolver.py +++ b/tests/test_0008_snapshot_realsolver.py @@ -40,9 +40,15 @@ pytestmark = [pytest.mark.level_2, pytest.mark.tier_a] -# Restore-vs-pristine reproducibility floor for this setup (measured). +# Restore-vs-pristine reproducibility floor for this setup. # The regretted-step guarantee is asserted bit-exact (np.array_equal); # only the never-stashed-control comparison uses this tolerance. +# +# The measured floor on this setup is ~7e-7 (see commit 3efc31b); +# the assertion threshold is set ~14× looser at 1e-5 deliberately, +# as headroom for variability across PETSc versions / BLAS libs / +# MPI ranks on CI. Do not tighten without measuring on the target +# environments first. (Comment added per Copilot review on #195.) _RESTORE_FLOOR_ATOL = 1.0e-5