diff --git a/docs/developer/design/in_memory_checkpoint_design.md b/docs/developer/design/in_memory_checkpoint_design.md new file mode 100644 index 00000000..10fec6ad --- /dev/null +++ b/docs/developer/design/in_memory_checkpoint_design.md @@ -0,0 +1,504 @@ +# In-memory checkpoint as a general UW3 capability + +Design note. Spun off from the deformable-surface architectural +discussion (2026-05-11) as a self-contained capability that is bounded, +useful in its own right, and not specifically tied to free-surface code. +The companion note that motivated this is +`docs/developer/design/deformable_surface_metronome_design_note.md` (in +the `feature/exp-integrator-freesurface` worktree). + +## Motivation + +**Primary use case: backtrack past unstable timestepping.** A timestepper +hits an instability or a sudden regime change; the cleanest recovery is +to restore the last known-good state and continue with smaller Δt +(or a different scheme). Today this is done ad-hoc by users where it +happens at all. + +**Secondary uses, all sharing the same primitive:** + +- Multi-stage time integration (RK4 between stages — restore to + start-of-step, deform to next stage configuration, sample rate). The + free-surface integrator-zoo work is the immediate case; the same + primitive applies to any multi-stage scheme. +- Adaptive Δt with error estimate (take a step, estimate per-step + error, restore + retry if too large). +- Predictor-corrector probing (try a predictor, check the corrector + residual, fall back if not converging — relevant to the VEP work). +- Regime-change feeling-out (e.g., elastic predictor → check yield + surface → restore + plastic split if violated). + +All five want one thing: at one moment, capture *enough* state that the +system can be put back at any point afterwards as if nothing had +happened. 
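The primary use case can be sketched in a few lines. This is a minimal illustration only, assuming a hypothetical `ToyModel` whose `snapshot()` / `restore()` methods mimic the API shape proposed in this note. The toy ODE, the stability test, and the helper `advance_with_backtrack` are illustrative assumptions, not UW3 code.

```python
import copy
import math


class ToyModel:
    """Hypothetical stand-in for a model: explicit Euler on du/dt = -k*u."""

    def __init__(self, u0, k):
        self.u = u0
        self.k = k
        self.time = 0.0

    def snapshot(self):
        # Eager copy of everything needed to resume (plain-Python token).
        return copy.deepcopy({"u": self.u, "time": self.time})

    def restore(self, token):
        self.u = token["u"]
        self.time = token["time"]

    def step(self, dt):
        self.u = self.u + dt * (-self.k * self.u)
        self.time += dt


def advance_with_backtrack(model, dt, t_end, max_halvings=20):
    """Step to t_end, halving dt whenever a step looks unstable."""
    while model.time < t_end - 1e-12:
        token = model.snapshot()              # last known-good state
        u_before = abs(model.u)
        for _ in range(max_halvings):
            h = min(dt, t_end - model.time)
            model.step(h)
            # Toy check: the exact solution decays, so growth signals instability.
            if math.isfinite(model.u) and abs(model.u) <= u_before:
                break
            model.restore(token)              # undo the bad step
            dt *= 0.5                         # retry with a smaller step
        else:
            raise RuntimeError("could not find a stable step size")
    return dt
```

Starting with `k = 10` and `dt = 1.0`, the first three attempts grow the solution and are rolled back; the loop settles on `dt = 0.125` and continues from the restored state as if the failed attempts had never happened.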
+ +## Reframe: same operation as on-disk checkpoint, different backend + +Rather than build a parallel snapshot mechanism, treat in-memory +snapshot as a *backend variant* of the existing checkpoint code. Two +benefits: + +1. **One source of truth for "what is system state."** Whatever set of + (mesh coords, MV DOFs, swarm positions, swarm-var values, algorithm + history, ...) the checkpoint serialiser captures becomes the + contract for in-memory snapshots too. Adding new state-bearers means + adding them once; both backends pick them up. +2. **The existing checkpoint code gets exercised harder and improved + as a side effect.** In-memory roundtrip is a cheap unit test — + every snapshot/restore cycle exercises the serialise/deserialise + paths, which surfaces gaps and bugs that disk-only checkpointing + exposes only at quarterly-test cadence. + +## Two checkpoint paths today (audit-confirmed) + +The audit (read-only investigation, current `main`) confirmed two +distinct on-disk paths in `src/underworld3/`: + +**Path A — PETScSection-based (native checkpoint).** +- `Mesh.write_checkpoint(...)` at `discretisation/discretisation_mesh.py:1892–1953`. +- Uses `dm.sectionView(viewer, subdm)` and `dm.globalVectorView(viewer, subdm, var._gvec)` + through a `PETSc.ViewerHDF5()`. +- Captures mesh topology, deformed coordinates, MV DOF values, and + swarm-variable values via the `_meshVar` proxy. +- Lower-level; bound to the producing DM. + +**Path B — write_timestep (visualization-oriented).** +- `Mesh.write_timestep(...)` at `discretisation/discretisation_mesh.py:1750–1830` + and `Swarm.write_timestep(...)` at `swarm.py:3726–3810`. +- Per-variable HDF5 + XDMF for ParaView; mixes `PETSc.ViewerHDF5()` + with direct `h5py` writes (`swarm.py:1772–1850`). +- More flexible — can be re-loaded at different resolution / + decomposition; bulkier; the user-facing visualisation pipeline. 
+ +**For in-memory snapshot, Path A is the conceptual model.** Restore +goes back to the same DM, so the resolution/decomposition flexibility +of Path B is unneeded. But the in-memory backend will not actually use +the HDF5 Viewer — it will copy section structure + global-vector data +directly into numpy arrays. Same conceptual capture, different +mechanism. + +## What state must be captured + +Audit-informed inventory: + +**Captured today (in Path A):** +- DM topology and section +- Deformed mesh coordinates +- Mesh-variable global vectors (DOF values) — **including `DDt.psi_star`, + which is itself a mesh variable; its DOF data goes through the + section path automatically** +- Swarm-variable values (via mesh-DM proxy fields) + +**NOT captured today, required for full-state in-memory restore:** + +| state | location | priority | +|---|---|---| +| `DDt._dt_history` (variable-Δt BDF history; list of floats) | `systems/ddt.py:386` | **high** | +| `DDt._history_initialised` (bool) | `systems/ddt.py:383` | **high** | +| `DDt._n_solves_completed` (int) | `systems/ddt.py:384` | **high** | +| Binding `DDt instance ↔ psi_star MV(s)` | implicit in `__init__` | high | +| Simulation time, step counter | not in any object today | high | +| Parameter mutation history | `parameters.py:145` (`_history`) | medium | +| `Swarm._mesh_version` | `swarm.py:2421` | medium | +| Solver iteration counts / convergence history | scattered | low | + +**The DDt example is representative, not exceptional.** Algorithm-internal +scalar/list state on Python objects is the architecturally significant +gap — it lives in solver-adjacent Python objects, not in any PETScSection. +The in-memory snapshot must compose PETScSection state + Python-side +mutable state into a single token. The next section addresses how this +composition should be designed in general — not as a per-class bolt-on, +but as a contract that new algorithm-helper classes follow from day one. 
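To make the composition concrete, here is a minimal sketch of a single token holding both array state and Python-side state. Everything in it is an assumption for illustration: `ToyDDt` stands in for a DDt-like helper, plain lists stand in for section and global-vector data, and `capture` / `restore` are hypothetical names. It copies every instance attribute wholesale, so it says nothing about how a class should declare which attributes count as state.

```python
import copy
from dataclasses import dataclass, field


@dataclass
class ToyDDt:
    """Hypothetical helper carrying scalar/list state on a Python object."""
    _dt_history: list = field(default_factory=list)
    _history_initialised: bool = False
    _n_solves_completed: int = 0


def capture(dof_arrays, helpers):
    """Compose array state and Python-side state into one plain token."""
    return {
        # Stand-in for section + global-vector data (name -> list of floats).
        "section_state": {name: list(vals) for name, vals in dof_arrays.items()},
        # Python-side mutable state, keyed by a stable name, not a live handle.
        "python_state": {name: copy.deepcopy(vars(h)) for name, h in helpers.items()},
    }


def restore(token, dof_arrays, helpers):
    """Write both kinds of state back into the live objects."""
    for name, vals in token["section_state"].items():
        dof_arrays[name][:] = vals            # in-place write-back
    for name, attrs in token["python_state"].items():
        for attr, value in attrs.items():
            # Copy again so the token stays reusable for repeated restores.
            setattr(helpers[name], attr, copy.deepcopy(value))
```

The token contains only plain Python containers, so it can be restored more than once and does not keep any live object alive.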
+ +## General serialisation contract for solver-internal state + +The opportunity. Bolting per-class snapshot hooks onto each algorithm +helper as it appears would produce a half-baked system that silently +misses state whenever a new helper is added without a corresponding hook. +DDt today, the next adaptive-Δt controller tomorrow, the gamma estimator +the day after — the pattern recurs. + +This is the right moment to design a general serialisation contract for +solver-internal state, before it leaks into the user API. New +algorithm-helper classes should declare their state slots from day one; +existing classes (DDt, parameter mutation history, solver convergence +bookkeeping) get retrofitted as they're touched. + +**Three options for the contract.** + +(A) **Declarative slots.** Class declares a class-level list of +attribute names that constitute state. +```python +class DDt: + _state_attrs = ('_dt_history', '_history_initialised', + '_n_solves_completed') +``` +Snapshot copies named attrs; restore writes them back. Stringly-typed; +silent breakage when attribute names drift. Least invasive — works as a +mixin on any existing class. + +(B) **Explicit save/restore methods.** Class implements paired methods. +```python +class DDt: + def _save_state(self) -> dict: ... + def _restore_state(self, d: dict) -> None: ... +``` +Most flexible (transform on save, validate on restore). Most boilerplate; +easy to forget to update when adding a new state attribute. Standard +Python pickling pattern (`__getstate__` / `__setstate__`) is essentially +this option. + +(C) **State as a first-class object.** Computation and state are +separated; the State object is a dataclass / pydantic model that +trivially serialises. +```python +@dataclass +class DDtState: + dt_history: list[float] + history_initialised: bool + n_solves_completed: int + psi_star_var_names: list[str] # binding to MVs lives here too + +class DDt: + def __init__(self, ...): + self.state = DDtState(...) 
+ # All mutations go via self.state. +``` +Most self-documenting; enables side benefits beyond serialisation +(deep-copy, equality testing, repr, schema-versioned migrations). Most +invasive — changes how new solver-internal classes are written. + +**Recommendation.** (C) for new code, (B)-style adapters for the small +number of existing classes that need retrofitting (primarily DDt and +parameter-mutation-history). (A) is rejected because the silent-drift +failure mode is exactly the kind of half-baked outcome we're trying to +avoid by designing this now. + +The decision matters because it sets the pattern for every +algorithm-internal class added over the next few years. The cost of +choosing wrong is that snapshot tokens silently miss state for any +class added without proper instrumentation; the cost of choosing right +is that new solver-internal classes get serialisation, deep-copy, and +equality testing automatically. + +**Side benefits of (C) beyond checkpoint.** +- `solver_a.state == solver_b.state` becomes a meaningful comparison — + useful for regression testing of solver-internal behaviour +- `repr(obj.state)` is automatic and useful for debugging +- snapshot tokens compose trivially: `{id(obj): obj.state.copy()}` +- schema versioning is tractable: a `_schema_version` field on each + State object plus a migration registry handles cross-version + compatibility (relevant for on-disk checkpoints that survive UW3 + upgrades) +- the `Snapshottable` interface becomes a one-liner: "has a `.state` + attribute that is a dataclass." + +**Bindings to PETSc-owned objects.** State objects must NOT hold direct +PETSc Vec / DM / MV handles — only stable identifiers (variable names, +mesh names) that can resolve back to the live object on restore. This +keeps tokens plain-Python and avoids the DM-lifecycle hazards +identified in earlier work. + +## A third backend: on-disk full-state snapshot + +The in-memory backend is one storage option. 
The same capture/restore +machinery supports a slower-but-persistent **on-disk full-state** +backend with no architectural changes — only the serialisation layer +differs. + +This is **distinct from the existing `write_timestep` on-disk path**, +which is selective (per-variable), designed for visualisation, emits +XDMF for ParaView, and does not restore solver-internal state. Three +backends serving three different needs: + +| dimension | existing `write_timestep` | new on-disk full-state | new in-memory | +|---|---|---|---| +| storage | HDF5 + XDMF, per-variable | HDF5, monolithic | dict of numpy arrays | +| selectivity | user picks variables | always full state | always full state | +| Python-side state | not captured | captured | captured | +| restorability | partial (re-init solvers, lose history) | bit-equivalent | bit-equivalent | +| persistence | survives process exit | survives process exit | intra-run only | +| speed | medium | slow (HDF5 I/O) | fast (RAM copy) | +| typical use | visualisation, restart-with-changes | crash recovery, bisection, debugging | RK staging, backtrack, adaptive Δt | + +The existing `write_timestep` continues to serve its role unchanged. +The new mechanism (in-memory + on-disk-full-state) addresses a +different need: faithful state restore for algorithmic uses where +"approximate restart" would silently corrupt the run. + +**Use cases the on-disk full-state backend opens up:** + +- **Crash recovery for long simulations.** Periodic snapshots written + to disk; on restart, restore from the most recent. No lost work + beyond the snapshot interval. +- **Cross-run resumption.** A simulation that ran to $t = T_1$ can be + picked up at $t = T_1$ in a later session, bit-equivalent — including + DDt history and any other algorithm-internal state. +- **Bisection / branching exploration.** Snapshot at a decision point; + try one parameter path; if unsatisfactory, restore and try another. + Powerful for sensitivity studies on long runs. 
+- **Debugging captures.** Snapshot at a problem point; examine offline; + iterate without re-running the costly setup. + +**What this adds to the design.** Three implications, all bounded: + +1. **Backend abstraction must serialise to bytes from day one.** The + in-memory backend stores numpy arrays directly; the on-disk backend + serialises them to HDF5. Both go through the same `save_*` / `load_*` + protocol on the backend interface — the abstraction is shaped by + needing both. +2. **Schema versioning becomes critical.** In-memory tokens are + short-lived; on-disk tokens may be loaded by a future UW3 version. + Each `State` dataclass carries a `_schema_version`; a migration + registry handles cross-version restore. (The on-disk + `write_timestep` path has no equivalent need because it doesn't + restore solver state.) +3. **Generation-counter semantics need a cross-process branch.** + Within-process: counter check ensures you don't restore an + invalidated snapshot. Across-process (restoring from disk on a + fresh run): the model is being initialised *from* the snapshot, + not restored *to* a previous state — counters are set to the + snapshot's values, no invalidation check needed. Restore code + distinguishes the two paths. + +## API shape + +Full-state always; backend chosen at snapshot time by passing (or +omitting) a path. + +```python +# Backend selection — same capture, different storage layer +token = model.snapshot() # in-memory (default) +model.snapshot(path='step42.snap.h5') # on-disk full-state + +model.restore(token) # in-memory restore +model.restore('step42.snap.h5') # on-disk restore + +# Existing per-variable selective on-disk path is unchanged: +mesh.write_timestep('step42.h5', ...) 
# visualisation; not full-state +``` + +Backends share a single `Snapshot` structure — only the serialisation +layer differs: + +```python +class Snapshot: + section_state: dict[DM, SectionAndVecData] # PETSc state + python_state: dict[ObjectId, ObjectState] # State dataclasses + generations: dict[ObjectId, int] # within-process invalidation + metadata: dict # sim time, step #, schema version +``` + +The in-memory backend stores `section_state` arrays as numpy buffers +directly; the on-disk backend serialises them to HDF5 datasets. The +Python-side `State` dataclasses serialise to HDF5 attributes (small +scalars / short lists) and groups (arrays). + +Tokens are plain Python / numpy — never PETSc Vec or DM handles. +On-disk "tokens" are paths to single HDF5 files. Either way, the +DM-lifecycle hazards identified in earlier work do not apply. + +## Restore semantics for swarms — rebuild, do not refuse + +**Correction (2026-05-11, post-review).** An earlier draft of this +section proposed using a per-swarm `_population_generation` counter as +an *invalidation gate*: if the counter at restore differs from the +counter at capture, raise rather than restore. That design is wrong. +The whole point of the toolkit is to undo state changes — including +particle motion / migration / repopulation between capture and +restore. Refusing on counter mismatch breaks the central use cases (RK +staging, backtrack-on-instability, adaptive Δt retry — all of which +*will* migrate particles between capture and restore). + +**The correct semantics: restore rebuilds the swarm's particle +population from the snapshot.** Specifically: + +1. Clear the swarm's current local particles. +2. Re-add the captured per-rank coordinates via + `add_particles_with_coordinates(saved_local_coords, migrate=False)`. 
+ The mesh partition is deterministic and unchanged within v1 scope + (mesh-version check still applies), so particles that were local at + capture are local at restore — no migration step needed. +3. Write the captured per-particle variable data back into the + newly-added particles, in their captured order. + +The `Swarm._population_generation` counter is still useful as +*informational metadata* — it can flag in logs / metadata what +happened between capture and restore, it can feed cache invalidation +in other consumers, it can power future optimisations (e.g., a fast +in-place restore when the counter happens to match). But it is **not** +a restore gate. + +Mutation sites where the counter is incremented (current line numbers +re-derived per audit; the design's correctness does not depend on the +exact set as long as we over-bump rather than under-bump): + +| file | site | call | +|---|---|---| +| `swarm.py` | end of `populate()` | covers internal `dm.addNPoints()` calls | +| `swarm.py` | `Swarm.migrate()` after `migration_disabled` early-exit | bumps unconditionally; conservative no-op safe | +| `swarm.py` | after `dm.migrate()` in `add_particles_with_coordinates()` | direct PETSc DM call, not via `Swarm.migrate` | +| `swarm.py` | after `addNPoints` in `add_particles_with_global_coordinates()` | catches `migrate=False` callers | +| `swarm.py` | `advection()` remesh path after re-injection `addNPoints` | recycle-mode reinitialisation | + +`Mesh._mesh_version` is a separate counter on the mesh side. In v1 +the mesh-version mismatch *does* refuse restore — see the next +section. + +## Architectural work required + +In rough dependency order: + +1. **Backend abstraction layer.** Extract the "capture" logic from the + "store" logic in `Mesh.write_checkpoint()`. Introduce a + `CheckpointBackend` protocol (e.g., `save_section`, `save_vector`, + `save_metadata`, `load_section`, `load_vector`, `load_metadata`). 
+ Shaped from day one to support both backends — the in-memory case + is the cheapest test of the abstraction's correctness; the on-disk + case is what locks in the byte-level serialisation contract. The + existing HDF5 path is refactored to fit the protocol; the wedge + is clean because section/vector ops are PETSc abstractions and the + HDF5 coupling lives in the Viewer construction. +2. **Backend implementations.** Two concrete backends: + - `InMemoryBackend` — dict of numpy arrays. Trivial once the + abstraction exists. + - `OnDiskFullStateBackend` — single monolithic HDF5 file. Shares + PETSc-state serialisation with the existing `write_checkpoint` + path (already HDF5); adds Python-state serialisation as HDF5 + attributes/groups. +3. **Adopt the serialisation contract for new solver-internal code.** + Decision: option (C) — state as first-class dataclass (see "General + serialisation contract" section). Any new algorithm-helper class + added from this point on declares its state as a separate + dataclass; the checkpoint mechanism reads `.state` automatically. + Document the pattern in the developer guide; add a check to CI that + new classes in `systems/` and `solvers/` declare `.state`. +4. **Retrofit existing solver-internal classes** to the contract. + In priority order: `DDt` (`systems/ddt.py`), parameter mutation + history (`parameters.py:145`), any solver convergence-tracking + state (audit pending). Each retrofit is small; total bounded by + the number of classes (probably under ten). +5. **Swarm rebuild on restore + informational + `_population_generation` counter.** Snapshot captures per-rank + particle coordinates and per-variable arrays. Restore clears the + current local population and re-adds the captured particles at + their captured coords (see "Restore semantics for swarms" section + above for the corrected design). The counter is bumped at every + identified mutation site for informational use; it is **not** an + invalidation gate. +6. 
**Schema versioning + migration registry.** Each `State` dataclass + carries a `_schema_version` integer. A central registry maps + `(class, version)` to migration functions that lift older State + data to the current schema. In-memory restore checks version + equality (any mismatch is a programming error since both sides are + the same process); on-disk restore consults the migration registry + and applies migrations in sequence. Without this infrastructure, + on-disk snapshots break the moment any retrofitted class evolves. +7. **`Model.snapshot()` / `Model.restore()` orchestration.** Walks + registered meshes, swarms, MVs, and any object exposing a `.state` + attribute; composes the token; routes to backend (in-memory or + on-disk) based on whether `path=` is given; restores in safe order; + distinguishes within-process restore (counter validation enforced) + from cross-process restore (counters initialised from snapshot). +8. **Tests.** In-memory roundtrip on standard setups (Stokes benchmark, + swarm-bearing setup, DDt-using setup). On-disk roundtrip on the + same setups. Cross-process restore: write to disk in process A, + load in fresh process B, verify the restored model continues + bit-equivalently from the snapshot point. Verify generation + invalidation raises cleanly; verify `.state == .state` after + roundtrip for retrofitted classes. + +## Scope boundaries (NOT in v1) + +- **Mesh adaptation roundtrip — scheduled for v1.2, not a permanent + limitation.** A snapshot taken before a `mesh.adapt()` event in v1 + refuses restore via the `_mesh_version` check, because the captured + DOF arrays are sized for the pre-adapt section and writing them + in-place into the post-adapt DM would corrupt the run. **v1.2 will + replace the refusal with a mesh-rebuild path**: capture enough + topology / section info to destroy the post-adapt DM and rebuild + the pre-adapt one, then write DOFs into the rebuilt DM. 
The + principle is the same as the swarm rebuild (capture-the-state, + rebuild-on-restore); the implementation is more invasive because + every MeshVariable / Swarm / solver holds references into the DM + and those wrappers need to re-bind. v1's snapshot **captures the + topology / section info even though v1 restore ignores it**, so the + payload is forward-compatible with v1.2 without a schema bump. +- **Replacing the existing `write_timestep` path.** The selective + per-variable on-disk path continues unchanged. The new full-state + on-disk backend is additive and serves a different need (faithful + restore vs visualisation/restart-with-changes). +- **Cross-rank-count restore.** On-disk full-state snapshots written + on N ranks restore on N ranks. Restoring on a different rank count + requires the existing `write_timestep` path's interpolation + machinery and is out of scope for the full-state backend. +- **Lazy / copy-on-write in-memory tokens.** Always-eager copy for + v1. The use cases that drove this (RK staging, single-step + backtrack) all happen on a per-step cadence where eager copy is + affordable. Lazy/COW is an optimisation for later if someone proves + it matters. + +## Open implementation questions + +1. **Does PETSc expose a clean way to copy section + vector state + directly into numpy without going through a Viewer?** The PETSc + `Vec.getArray()` / `Section.getDof()` low-level APIs should + suffice, but it needs a quick prototype to confirm. Alternative: + `PETSc.Viewer.Type.MEMORY` may exist but support for DM operations + is uncertain. +2. **How does `Model.snapshot()` discover state-bearing objects?** + With the contract decision made (option C — `.state` dataclass), + discovery options narrow to: (a) walk Model's registered solvers + and recurse into anything with a `.state` attr; (b) explicit + registration at construction time. (a) is more ergonomic but + relies on solvers maintaining proper registration with the Model. 
+ (b) is more explicit but requires every state-bearing object to + know about Model. (a) is probably right; verify the existing + solver-Model relationship can support recursive `.state` discovery. +3. **What does `Model.restore(token)` order look like?** Mesh + coords first (since MV DOFs are tied to mesh layout), then MV + DOFs, then swarm positions + migrate, then swarm-var values, then + DDt and Python-side state, then generation counters validated last. + Confirm this order is correct; particularly whether MV restoration + needs the mesh to be at restored coords first. +4. **Memory cost on realistic setups.** For a typical coupled-physics + run (mesh + swarm + several MVs + DDt history), what's the + per-snapshot byte cost? Drives the answer to "can users keep N + snapshots in memory without thinking about it." Probably bounded + by one-Stokes-solve worth of memory, but worth measuring early. + +## Status + +Audit complete. Design pending review. Implementation has not started. +The work is bounded (no open-ended research items in any of the eight +architectural-work items above) and decomposes into commits that can +land sequentially: backend abstraction → InMemoryBackend → Python-side +registration → swarm generation counter → Model orchestration → tests. + +Expected size: around four weeks of careful work, dominated by: +- the backend-abstraction refactor (item 1), which touches existing + checkpoint code other things depend on, and must be shaped to + support both backends from day one; +- the contract retrofit (item 4), mechanically small per class but + needing careful audit so no state is silently missed; +- the on-disk backend (item 2b) and schema versioning (item 6), + together adding around a week beyond the in-memory-only scope but + unlocking the cross-run / crash-recovery / bisection use cases. 
+ +The contract decision (item 3) and the schema-versioning decision +(item 6) are the highest-leverage pieces — they set the pattern for +every algorithm-internal class added over the next few years and the +durability guarantee for every on-disk full-state snapshot. Worth +getting right even if they slow the immediate work, because the +alternative is years of half-baked snapshots that silently miss state +or break across UW3 versions. + +## Cross-references + +- `docs/developer/design/deformable_surface_metronome_design_note.md` + (in `feature/exp-integrator-freesurface` worktree) — the parent + design discussion that motivated spinning this off as a separate + capability. +- `publications/free-surface-paper/integrator_zoo_supplementary.md` — + the empirical work that exposed the need for snapshot/restore as + part of multi-stage time integration. diff --git a/docs/developer/guides/state-as-dataclass.md b/docs/developer/guides/state-as-dataclass.md new file mode 100644 index 00000000..37f34718 --- /dev/null +++ b/docs/developer/guides/state-as-dataclass.md @@ -0,0 +1,212 @@ +# State-as-dataclass — the snapshot contract for solver-internal helpers + +When adding a new solver-internal helper class (a time-derivative +manager, a controller, a convergence tracker, anything carrying +mutable evolution state in Python attributes), you should declare its +mutable state as a dataclass exposed via a `.state` attribute. The +unitary snapshot / restore toolkit +([design note](../design/in_memory_checkpoint_design.md), +implementation in `src/underworld3/checkpoint/`) discovers state +through this attribute automatically — no per-class registration of +which attributes are "state" required. + +## Why a contract + +Without a uniform contract, every helper invents its own way of +persisting mutable state — or doesn't, and silently loses information +across snapshot / restore. 
The audit that motivated this work found +several classes (DDt, parameter mutation history, swarm population +counter, ...) each with bespoke state-tracking and no path to +snapshot. The contract makes the obligation explicit: if your class +has mutable state, declare it as a dataclass, end of story. + +Side benefits of the contract beyond snapshot: +- `obj_a.state == obj_b.state` becomes a meaningful comparison — useful + for regression-testing solver-internal behaviour. +- `repr(obj.state)` is automatic and useful for debugging. +- Schema versioning is tractable: a `_schema_version` field on each + State dataclass plus a migration registry (v1.1) handles + cross-version compatibility for on-disk snapshots. + +## The interface + +```python +from dataclasses import dataclass, field +from underworld3.checkpoint import SnapshottableState + +# Inherit from SnapshottableState — gives you the _schema_version field. +@dataclass +class MyHelperState(SnapshottableState): + # _schema_version: int = 1 # inherited; override on schema change + counter: int = 0 + history: list[float] = field(default_factory=list) + config_name: str = "" +``` + +The host class exposes the dataclass via a `.state` attribute +(property is fine; stored attribute is fine). + +```python +class MyHelper(uw_object): # uw_object gives you instance_number + def __init__(self, ...): + super().__init__() + # ... set up your state ... + + # Self-register so Model.snapshot() discovers you. 
+ try: + import underworld3 as _uw + _uw.get_default_model()._register_state_bearer(self) + except Exception: + pass + + @property + def state(self) -> MyHelperState: + return MyHelperState( + counter=self._counter, + history=list(self._history), + config_name=self._config_name, + ) + + @state.setter + def state(self, s: MyHelperState) -> None: + if s._schema_version != MyHelperState._schema_version: + raise ValueError("schema mismatch") + self._counter = s.counter + self._history = list(s.history) + self._config_name = s.config_name + # If your class has *derived* state (caches, downstream + # coefficients, ...) recompute it here so post-restore reads + # are consistent without waiting for the next solve. +``` + +The protocol check is structural: any object exposing a `.state` +attribute that is a `SnapshottableState` instance is `Snapshottable`. + +## Option (B) vs (C): derived vs stored dataclass + +The design note recommends two of its three implementation styles. Both satisfy the +protocol; pick based on whether the class is new or being retrofitted. + +### Option (C): state is the authoritative store. *Prefer for new code.* + +```python +class MyNewHelper: + def __init__(self, ...): + self.state = MyHelperState(counter=0, history=[]) + + def step(self, dx): + self.state.counter += 1 + self.state.history.append(dx) +``` + +Every mutation site reads/writes `self.state.` directly. The +dataclass *is* the storage. Self-documenting; mistakes (forgetting to +add a new field to the state object) are obvious because adding state +involves adding a field to the dataclass. + +### Option (B): state is a derived view over private attrs. *Use for retrofits.* + +```python +class MyExistingHelper: + def __init__(self, ...): + self._counter = 0 + self._history = [] + # ... legacy private attrs throughout the class ... 
+ + @property + def state(self) -> MyHelperState: + return MyHelperState(counter=self._counter, history=list(self._history)) + + @state.setter + def state(self, s): + self._counter = s.counter + self._history = list(s.history) +``` + +The existing private attrs stay as-is; `.state` is a façade. Less +invasive than option (C) because the existing call sites don't +change — only the new accessor is added. + +The five `DDt` flavors in `src/underworld3/systems/ddt.py` (`Symbolic`, +`Eulerian`, `SemiLagrangian`, `Lagrangian`, `Lagrangian_Swarm`) use +option (B) for this reason: they predate the contract, and rewriting +every mutation site would be churn without behavior change. + +## What goes in the State dataclass + +**Yes:** +- Mutable evolution-tracking state — anything that changes after + construction and affects subsequent behaviour. Counters, history + buffers, current-step values, mutation logs. +- Names / stable IDs of bound objects (mesh-variable names, swarm + names) — restore uses these to verify the binding still holds. +- Configuration flags whose state is meaningful at snapshot time + (`with_forcing_history`, `recycle_rate`, ...) — they help restore + detect a mid-run reconfiguration that would invalidate the snapshot. + +**No:** +- Live PETSc Vec, DM, or solver handles. Tokens must stay plain + Python / numpy so they survive DM lifecycle changes (and so v1.1's + on-disk serialisation works). +- Constructor-time arguments that never change (mesh reference, + variable type, polynomial degree). Re-derive on restore from the + surviving wrapper. +- Bulk numerical data like mesh-variable DOFs or swarm-variable + arrays. Those travel via the dedicated mesh-var / swarm-var + snapshot paths. +- Caches, derived coefficients, anything you can recompute from + primary state in the `.state` setter. + +## Schema versioning + +`_schema_version` exists for a v1.1 / v1.2 feature: on-disk snapshots +that survive across UW3 versions. 
When you change the shape of a +State dataclass (rename a field, add a required field, change a +type), bump the version and add a migration entry. Within a single +process (v1 in-memory only), the version is checked for strict +equality — any mismatch is a programming error. + +The migration registry itself is v1.1 work (item 6 of the design +note). For now: define `_schema_version: int = 1` on every State +dataclass and don't worry about migrations until on-disk lands. + +## Where to put your State dataclass + +Next to the host class. `DDtSymbolicState` lives next to +`Symbolic` in `src/underworld3/systems/ddt.py`. This keeps the +dataclass and the class it describes co-located; changes to one show +up in the diff against the other. + +## Testing + +A typical snapshot test for a state-bearing class: + +```python +def test_my_helper_roundtrip(): + uw.reset_default_model() + model = uw.get_default_model() + h = MyHelper(...) + + # Advance state. + h.step(0.1) + h.step(0.2) + state_pre = h.state + + snap = model.snapshot() + + # Mutate. + h.step(0.5) + + model.restore(snap) + + # Verify primary state recovered. + assert h.state == state_pre +``` + +The dataclass `__eq__` makes the final assertion a one-liner. + +## Related + +- [Design note](../design/in_memory_checkpoint_design.md) — full design rationale, scope, and roadmap. +- `src/underworld3/checkpoint/state.py` — protocol definitions. +- `src/underworld3/checkpoint/snapshot.py` — capture / restore orchestration. +- `src/underworld3/systems/ddt.py` — five working examples (option (B) retrofits). 
diff --git a/src/underworld3/__init__.py b/src/underworld3/__init__.py index ad4b80e6..d37bcb20 100644 --- a/src/underworld3/__init__.py +++ b/src/underworld3/__init__.py @@ -207,6 +207,7 @@ def view(): import underworld3.parameters import underworld3.materials import underworld3.discretisation.persistence +import underworld3.checkpoint from .model import ( Model, diff --git a/src/underworld3/checkpoint/__init__.py b/src/underworld3/checkpoint/__init__.py new file mode 100644 index 00000000..fcde6412 --- /dev/null +++ b/src/underworld3/checkpoint/__init__.py @@ -0,0 +1,40 @@ +"""Unitary in-memory (and, later, on-disk) snapshot toolkit. + +The first true unitary checkpoint in Underworld3 — captures enough state +that a Model can be put back exactly as it was, suitable for backtrack on +failure, multi-stage time integration, adaptive-Δt retry, and crash +recovery. + +Distinct from the existing per-variable ``write_timestep`` / +``read_timestep`` path, which serves visualisation and partial restart. +That path stays in service of its existing role. + +See ``docs/developer/design/in_memory_checkpoint_design.md`` for the +design rationale, scope, and roadmap. In v1 (this code), only an +in-memory backend is implemented and only mesh + mesh-variable state is +captured. Subsequent PRs add swarm coverage, solver-internal Python +state (DDt history, parameter mutation history), an on-disk full-state +backend, and schema versioning across UW3 releases. 
+""" + +from .backend import CheckpointBackend, InMemoryBackend +from .snapshot import ( + SNAPSHOT_SCHEMA_VERSION, + Snapshot, + SnapshotInvalidatedError, + snapshot, + restore, +) +from .state import Snapshottable, SnapshottableState + +__all__ = [ + "CheckpointBackend", + "InMemoryBackend", + "SNAPSHOT_SCHEMA_VERSION", + "Snapshot", + "SnapshotInvalidatedError", + "snapshot", + "restore", + "Snapshottable", + "SnapshottableState", +] diff --git a/src/underworld3/checkpoint/backend.py b/src/underworld3/checkpoint/backend.py new file mode 100644 index 00000000..271e486d --- /dev/null +++ b/src/underworld3/checkpoint/backend.py @@ -0,0 +1,80 @@ +"""Storage protocol for snapshot tokens. + +The protocol is shaped from day one to support both an in-memory backend +(dict of numpy arrays, v1) and an on-disk full-state backend (HDF5, +v1.1). Per the design note, the in-memory backend is the cheapest +correctness test of the abstraction; the on-disk backend locks in the +byte-level serialisation contract. +""" + +from __future__ import annotations + +from typing import Any, Protocol, runtime_checkable + +import numpy as np + + +@runtime_checkable +class CheckpointBackend(Protocol): + """Backing-store interface for :class:`underworld3.checkpoint.Snapshot`. + + Implementations + --------------- + - :class:`InMemoryBackend` — v1; numpy arrays held in process memory. + - ``OnDiskFullStateBackend`` — v1.1; single monolithic HDF5 file. + + Vectors are bulk numerical data; metadata is small scalars / dicts / + lists describing structure and provenance. + """ + + def save_vector(self, key: str, array: np.ndarray) -> None: ... + + def load_vector(self, key: str) -> np.ndarray: ... + + def save_metadata(self, key: str, value: Any) -> None: ... + + def load_metadata(self, key: str) -> Any: ... + + def list_vectors(self) -> list[str]: ... + + def list_metadata(self) -> list[str]: ... + + +class InMemoryBackend: + """Snapshot storage in process memory. 
+ + Eager-copy on both ``save_vector`` and ``load_vector`` per the v1 + scope-boundary (no lazy / copy-on-write semantics). Per-snapshot + byte cost is the sum of captured vector sizes — expected to be + bounded by one Stokes solve's working memory for typical setups. + """ + + def __init__(self) -> None: + self._vectors: dict[str, np.ndarray] = {} + self._metadata: dict[str, Any] = {} + + def save_vector(self, key: str, array: np.ndarray) -> None: + if key in self._vectors: + raise KeyError(f"vector key already present in snapshot: {key!r}") + self._vectors[key] = np.asarray(array).copy() + + def load_vector(self, key: str) -> np.ndarray: + if key not in self._vectors: + raise KeyError(f"vector key not in snapshot: {key!r}") + return self._vectors[key].copy() + + def save_metadata(self, key: str, value: Any) -> None: + if key in self._metadata: + raise KeyError(f"metadata key already present in snapshot: {key!r}") + self._metadata[key] = value + + def load_metadata(self, key: str) -> Any: + if key not in self._metadata: + raise KeyError(f"metadata key not in snapshot: {key!r}") + return self._metadata[key] + + def list_vectors(self) -> list[str]: + return list(self._vectors.keys()) + + def list_metadata(self) -> list[str]: + return list(self._metadata.keys()) diff --git a/src/underworld3/checkpoint/snapshot.py b/src/underworld3/checkpoint/snapshot.py new file mode 100644 index 00000000..eff45bac --- /dev/null +++ b/src/underworld3/checkpoint/snapshot.py @@ -0,0 +1,386 @@ +"""Unitary state capture and restore. + +A :class:`Snapshot` is a plain-Python token holding numpy data and small +metadata. It contains no PETSc Vec or DM handles, so it survives object +lifecycle changes within a process. Within a process, a snapshot can be +restored back onto the same :class:`underworld3.Model` instance; across +processes (v1.1, on-disk backend) the model is initialised from the +snapshot rather than restored to a previous state. 
+ +Forward-compatibility for v1.2 (mesh-adapt rebuild on restore) +--------------------------------------------------------------- +The snapshot module is structured so that v1.2 can replace the +``_mesh_version`` refusal with a true mesh-DM rebuild without touching +this module. Two principles support this: + +- **Capture by stable name, not by Python id.** Meshes are keyed by + ``mesh.name``, swarms by ``f"swarm_{instance_number}"``. Within a + single process this is overkill (object id would work); but it + trivialises v1.1 cross-process restore and v1.2 mesh-rebuild (where + the wrapper object survives but its DM is destroyed and recreated). + +- **Wrappers, not the snapshot module, decide how to apply a payload.** + ``Mesh.apply_snapshot_payload()`` and ``Swarm.apply_snapshot_payload()`` + receive a self-contained dict and decide what to do with it. v1 + implementations are in-place writes; v1.2's mesh implementation can + inspect the topology slot of the payload (left ``None`` by v1 + capture) and rebuild the DM if needed, without any change to the + capture / orchestration here. + +This module implements the v1 scope: mesh coordinates and mesh-variable +DOFs, plus swarm positions and user swarm-variable data with +rebuild-on-restore semantics. Solver-internal Python state, on-disk +backend, schema versioning, mesh-DM rebuild, and cross-process restore +are scheduled for follow-up PRs per the design note. +""" + +from __future__ import annotations + +import copy +from dataclasses import dataclass, field +from typing import Any, Optional + +import numpy as np + +from .backend import CheckpointBackend, InMemoryBackend +from .state import SnapshottableState + + +SNAPSHOT_SCHEMA_VERSION = 1 + + +class SnapshotInvalidatedError(RuntimeError): + """Raised when a snapshot can no longer be restored faithfully. + + Triggers in v1: + + - A captured mesh / swarm / variable name is no longer present on + the target :class:`underworld3.Model`. 
+ - Mesh ``_mesh_version`` differs from the snapshot's captured + value. v1 treats this as fatal because the captured DOF arrays + are sized for the pre-adapt section. **v1.2 will replace this + refusal with a mesh-rebuild path** on the same principle as the + swarm rebuild — see the design note's mesh-adapt scope section. + + Notably **not** a trigger: swarm population mutation + (populate / migrate / add_particles / remesh) between capture and + restore. The swarm restore path *rebuilds* the local particle + population from the snapshot, so intervening mutations are exactly + what restore is for. The ``_population_generation`` counter on the + swarm is informational, not a restore gate. + """ + + +def _swarm_stable_name(swarm) -> str: + """Per-process stable name for a swarm. Uses uw_object instance number.""" + return f"swarm_{swarm.instance_number}" + + +@dataclass +class Snapshot: + """Unitary state token. + + Produced by :func:`snapshot`; consumed by :func:`restore`. Holds a + backend (where the bulk arrays live) plus per-model bookkeeping — + which meshes and swarms were captured, in what order, with what + variable sets. + + Attributes + ---------- + backend + Where the captured arrays and small metadata live. + schema_version + Snapshot file-format version. Restore refuses on mismatch in + v1; v1.1's migration registry will lift older versions to the + current schema for on-disk restore only. + mesh_names + Capture order of mesh names. ``mesh.name`` is the stable key. + mesh_versions + Per-mesh ``_mesh_version`` at the moment of capture. v1 + compares strictly; v1.2 will rebuild on mismatch. + meshvar_names + Mapping ``mesh_name → [var clean_name, ...]``. + swarm_names + Capture order of swarm stable names + (``f"swarm_{instance_number}"``). + swarm_mesh_names + Mapping ``swarm_name → mesh_name`` so restore can verify the + swarm's parent mesh is still the captured one. 
+ swarm_generations + Captured ``_population_generation`` per swarm — informational + metadata; *not* a restore gate. Useful for logs and debugging + ("this snapshot was taken at generation 7; the current swarm + is at 12"). + swarmvar_names + Mapping ``swarm_name → [user-var clean_name, ...]``. Internal + DMSwarm-prefixed variables are filtered out. + metadata + Free-form user/system metadata (simulation time, step counter, + ...). Not load-bearing for restore correctness. + """ + + backend: CheckpointBackend + schema_version: int = SNAPSHOT_SCHEMA_VERSION + mesh_names: list[str] = field(default_factory=list) + mesh_versions: dict[str, int] = field(default_factory=dict) + meshvar_names: dict[str, list[str]] = field(default_factory=dict) + swarm_names: list[str] = field(default_factory=list) + swarm_mesh_names: dict[str, str] = field(default_factory=dict) + swarm_generations: dict[str, int] = field(default_factory=dict) + swarmvar_names: dict[str, list[str]] = field(default_factory=dict) + # State-bearer captures: list of (stable_key, state_dataclass). + # stable_key is f"{type(obj).__name__}_{obj.instance_number}", matched + # at restore against the same key derived from currently-registered + # state-bearers. List preserves capture order — informational only, + # since lookup is by key. 
+ state_bearers: list = field(default_factory=list) + metadata: dict[str, Any] = field(default_factory=dict) + + +# ----- Backend key conventions ----- + +def _mesh_coords_key(mesh_name: str) -> str: + return f"mesh:{mesh_name}:coords" + + +def _meshvar_key(mesh_name: str, var_clean_name: str) -> str: + return f"mesh:{mesh_name}:var:{var_clean_name}:gvec" + + +def _swarm_coords_key(swarm_name: str) -> str: + return f"swarm:{swarm_name}:coords" + + +def _swarmvar_key(swarm_name: str, var_clean_name: str) -> str: + return f"swarm:{swarm_name}:var:{var_clean_name}:data" + + +def _is_internal_swarmvar(var_name: str) -> bool: + """Filter PETSc-managed internal swarm variables from user capture. + + ``DMSwarmPIC_coor`` is captured separately via the particle-coords + path. ``DMSwarm_X0`` and ``DMSwarm_remeshed`` carry recycle-related + bookkeeping that is regenerated on next solve and is out of scope + for v1 capture. + """ + return var_name.startswith("DMSwarm") + + +# ----- Capture (orchestration) ----- + +def snapshot(model, *, path: Optional[str] = None) -> Snapshot: + """Capture a unitary snapshot of the model's current state. + + Captures, in v1: each registered mesh's deformed coordinates and + every mesh-variable's global-vector DOFs; each registered swarm's + per-rank particle coordinates and user swarm-variable arrays. + + Pass ``path=...`` once the v1.1 on-disk backend lands. v1 raises + ``NotImplementedError``. + + See ``docs/developer/design/in_memory_checkpoint_design.md`` for + the design rationale and scope boundaries. 
+ """ + if path is not None: + raise NotImplementedError( + "on-disk full-state snapshot is scheduled for v1.1; " + "v1 supports the in-memory backend only" + ) + + snap = Snapshot(backend=InMemoryBackend()) + for mesh in list(model._meshes.values()): + _capture_mesh(snap, mesh) + for swarm in list(model._swarms.values()): + _capture_swarm(snap, swarm) + for obj in list(model._state_bearers): + _capture_state_bearer(snap, obj) + return snap + + +def _capture_mesh(snap: Snapshot, mesh) -> None: + payload = mesh.snapshot_payload() + name = payload["name"] + if name in snap.mesh_names: + raise RuntimeError( + f"duplicate mesh name {name!r} in snapshot capture; mesh names " + f"must be unique within a Model" + ) + snap.mesh_names.append(name) + snap.mesh_versions[name] = payload["mesh_version"] + + snap.backend.save_vector(_mesh_coords_key(name), payload["coords"]) + + var_names: list[str] = [] + for var_clean_name, gvec_array in payload["vars"].items(): + snap.backend.save_vector(_meshvar_key(name, var_clean_name), gvec_array) + var_names.append(var_clean_name) + snap.meshvar_names[name] = var_names + + +def _state_bearer_key(obj) -> str: + """Stable per-process key for a Snapshottable. ``instance_number`` + comes from ``uw_object`` and is unique across the run.""" + return f"{type(obj).__name__}_{obj.instance_number}" + + +def _capture_state_bearer(snap: Snapshot, obj) -> None: + """Pull ``obj.state`` and store a deep copy. + + Deep copy ensures later mutations on the live state-bearer don't + leak into the captured token. The dataclass itself is the storage + here (no separate backend.save_vector call) because state + dataclasses are small Python objects, not bulk numerical arrays. + v1.1's on-disk backend will route the dataclass through the + backend (HDF5 attrs/groups); v1 holds them in the in-memory Snapshot + directly. 
+ """ + state = obj.state + if not isinstance(state, SnapshottableState): + raise TypeError( + f"{type(obj).__name__}.state must be a SnapshottableState, " + f"got {type(state).__name__}" + ) + snap.state_bearers.append((_state_bearer_key(obj), copy.deepcopy(state))) + + +def _capture_swarm(snap: Snapshot, swarm) -> None: + payload = swarm.snapshot_payload() + name = payload["name"] + if name in snap.swarm_names: + raise RuntimeError( + f"duplicate swarm name {name!r} in snapshot capture" + ) + snap.swarm_names.append(name) + snap.swarm_mesh_names[name] = payload["mesh_name"] + snap.swarm_generations[name] = payload["population_generation"] + + snap.backend.save_vector(_swarm_coords_key(name), payload["coords"]) + + var_names: list[str] = [] + for var_clean_name, data in payload["vars"].items(): + snap.backend.save_vector(_swarmvar_key(name, var_clean_name), data) + var_names.append(var_clean_name) + snap.swarmvar_names[name] = var_names + + +# ----- Restore (orchestration) ----- + +def restore(model, snap: Snapshot) -> None: + """Restore the model from a snapshot. + + Mesh restore in v1 writes captured coords + DOFs back in place. If + the mesh's ``_mesh_version`` has moved since capture, restore + raises :class:`SnapshotInvalidatedError` — this becomes a rebuild + path in v1.2. + + Swarm restore *rebuilds* the local particle population: clears + current particles, re-adds at captured coords, writes captured + per-variable data back in order. This is the rebuild-on-restore + semantics described in the design note's "Restore semantics for + swarms" section — restore is precisely *for* the case where + particles have moved / been added / been removed since capture. + + Parameters + ---------- + model + The :class:`underworld3.Model` to restore. Must be the same + instance the snapshot came from (within-process restore). + snap + Token returned by :func:`snapshot`. 
+ + Raises + ------ + SnapshotInvalidatedError + Captured mesh / swarm / variable is no longer registered on + the model, or mesh ``_mesh_version`` has moved since capture + (mesh-adapt is v1.2 scope). + TypeError + ``snap`` is not a :class:`Snapshot`. + """ + if not isinstance(snap, Snapshot): + raise TypeError( + f"expected underworld3.checkpoint.Snapshot, got {type(snap).__name__}" + ) + if snap.schema_version != SNAPSHOT_SCHEMA_VERSION: + raise SnapshotInvalidatedError( + f"snapshot schema version {snap.schema_version} does not match " + f"current {SNAPSHOT_SCHEMA_VERSION}; on-disk migration is v1.1" + ) + + meshes_by_name = {m.name: m for m in model._meshes.values()} + swarms_by_name = {_swarm_stable_name(s): s for s in model._swarms.values()} + + for mesh_name in snap.mesh_names: + mesh = meshes_by_name.get(mesh_name) + if mesh is None: + raise SnapshotInvalidatedError( + f"mesh {mesh_name!r} from snapshot is not registered on " + f"this Model; within-process restore requires the originating " + f"Model" + ) + payload = _build_mesh_payload(snap, mesh_name) + mesh.apply_snapshot_payload(payload) + + for swarm_name in snap.swarm_names: + swarm = swarms_by_name.get(swarm_name) + if swarm is None: + raise SnapshotInvalidatedError( + f"swarm {swarm_name!r} from snapshot is not registered on " + f"this Model" + ) + expected_mesh_name = snap.swarm_mesh_names[swarm_name] + if swarm.mesh.name != expected_mesh_name: + raise SnapshotInvalidatedError( + f"swarm {swarm_name!r} parent mesh changed from " + f"{expected_mesh_name!r} to {swarm.mesh.name!r} since " + f"snapshot" + ) + payload = _build_swarm_payload(snap, swarm_name) + swarm.apply_snapshot_payload(payload) + + if snap.state_bearers: + bearers_by_key = { + _state_bearer_key(o): o for o in list(model._state_bearers) + } + for key, captured_state in snap.state_bearers: + obj = bearers_by_key.get(key) + if obj is None: + raise SnapshotInvalidatedError( + f"state-bearer {key!r} from snapshot is not registered " + 
f"on this Model; restore requires the originating " + f"Model" + ) + obj.state = copy.deepcopy(captured_state) + + +def _build_mesh_payload(snap: Snapshot, mesh_name: str) -> dict: + return { + "name": mesh_name, + "captured_mesh_version": snap.mesh_versions[mesh_name], + "coords": snap.backend.load_vector(_mesh_coords_key(mesh_name)), + # Topology is None in v1; v1.2 mesh-rebuild path will populate + # this slot (e.g., section view data) without bumping the + # schema version, because v1 reads ignore the key. + "topology": None, + "vars": { + var_clean_name: snap.backend.load_vector( + _meshvar_key(mesh_name, var_clean_name) + ) + for var_clean_name in snap.meshvar_names[mesh_name] + }, + } + + +def _build_swarm_payload(snap: Snapshot, swarm_name: str) -> dict: + return { + "name": swarm_name, + "mesh_name": snap.swarm_mesh_names[swarm_name], + "captured_population_generation": snap.swarm_generations[swarm_name], + "coords": snap.backend.load_vector(_swarm_coords_key(swarm_name)), + "vars": { + var_clean_name: snap.backend.load_vector( + _swarmvar_key(swarm_name, var_clean_name) + ) + for var_clean_name in snap.swarmvar_names[swarm_name] + }, + } diff --git a/src/underworld3/checkpoint/state.py b/src/underworld3/checkpoint/state.py new file mode 100644 index 00000000..646b8b2a --- /dev/null +++ b/src/underworld3/checkpoint/state.py @@ -0,0 +1,77 @@ +"""State-as-dataclass contract for snapshot capture/restore. + +Per the design note's "General serialisation contract for solver-internal +state" section, the chosen pattern for new solver-internal classes is +**option (C): state as a first-class dataclass**. A class declares a +``state`` attribute that is a dataclass; the snapshot mechanism reads it +automatically. 
+ +For existing classes being retrofitted (DDt today, parameter mutation +history next, any solver convergence-tracking state after that), the +design recommends **option (B)-style adapters**: a ``state`` property +that *derives* a dataclass from the existing private attributes, and a +matching setter that writes the dataclass values back. The dataclass is +not the authoritative store — the private attrs are — but the snapshot +mechanism only sees the dataclass, so the snapshot/restore contract is +the same as for option (C). + +This module defines: + +- :class:`SnapshottableState` — the dataclass base every state object + should inherit / mimic. Carries the ``_schema_version`` field that + drives the v1.1 on-disk migration registry. +- :class:`Snapshottable` — a structural protocol; an object is + Snapshottable if it exposes a ``state`` attribute that is a + :class:`SnapshottableState`. Used by ``Model.snapshot()`` discovery. + +The actual State dataclasses for specific classes (``DDtState``, +``ParameterRegistryState``) live next to those classes in the systems +/ utilities packages. +""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Protocol, runtime_checkable + + +@dataclass +class SnapshottableState: + """Base for every per-class State dataclass. + + Subclasses add their own fields. The single mandatory field is + ``_schema_version`` — an integer that v1.1's on-disk migration + registry uses to lift older snapshots to the current schema. In + v1 (in-memory only) the version is checked for strict equality; + any mismatch is a programming error since capture and restore + happen in the same process. + """ + + _schema_version: int = 1 + + +@runtime_checkable +class Snapshottable(Protocol): + """Structural protocol for state-bearing objects. 
+ + An object is Snapshottable if:: + + obj.state # readable + obj.state = obj.state # writable + isinstance(obj.state, SnapshottableState) + + The snapshot mechanism uses :attr:`state` to capture and restore. + Implementations choose between: + + - **Option (C), authoritative dataclass.** ``state`` is a stored + attribute holding the dataclass; every mutation site on the class + writes fields on ``self.state`` directly. Best for new code. + + - **Option (B), derived dataclass.** ``state`` is a property that + builds the dataclass from existing private attrs on each read, + and the setter writes attrs back. Best for retrofits of existing + classes (DDt, ParameterRegistry). + """ + + @property + def state(self) -> SnapshottableState: ... diff --git a/src/underworld3/discretisation/discretisation_mesh.py b/src/underworld3/discretisation/discretisation_mesh.py index bf3e4075..515c7511 100644 --- a/src/underworld3/discretisation/discretisation_mesh.py +++ b/src/underworld3/discretisation/discretisation_mesh.py @@ -2567,6 +2567,107 @@ def write_checkpoint( uw.mpi.barrier() # should not be required viewer.destroy() + # ----- Unitary snapshot / restore ----- + # + # See ``src/underworld3/checkpoint/snapshot.py`` and + # ``docs/developer/design/in_memory_checkpoint_design.md``. v1 + # captures deformed coords + per-MV global-vector DOFs; v1.2 will + # add topology / section capture so the DM can be rebuilt on + # restore after ``mesh.adapt()``. + + def snapshot_payload(self) -> dict: + """Return a self-contained dict describing this mesh's state. + + The returned dict is consumed by + :mod:`underworld3.checkpoint.snapshot` capture. Keys: + + - ``name``: stable string identifier for the mesh. + - ``mesh_version``: current ``_mesh_version`` integer. + - ``coords``: deformed mesh coordinates (numpy array). + - ``vars``: ``{var.clean_name: gvec_array.copy()}`` for every + mesh variable on this mesh with an allocated ``_gvec``. 
+ + v1.2 will additionally populate a ``topology`` key with + section / DM-topology data sufficient to rebuild the DM on + restore. + """ + coords = numpy.asarray(self.X.coords).copy() + var_arrays: Dict[str, numpy.ndarray] = {} + for var in self.vars.values(): + var._sync_lvec_to_gvec() + # Variables created but never touched have _gvec=None (lazy + # allocation in MeshVariable). They carry no data so they + # contribute nothing to the snapshot — skip cleanly. + if var._gvec is None: + continue + var_arrays[var.clean_name] = numpy.asarray(var._gvec.array).copy() + return { + "name": self.name, + "mesh_version": int(getattr(self, "_mesh_version", 0)), + "coords": coords, + "vars": var_arrays, + } + + def apply_snapshot_payload(self, payload: dict) -> None: + """Restore this mesh from a payload produced by :meth:`snapshot_payload`. + + v1 implementation writes coordinates and per-variable DOFs + back in place. The captured DOF arrays must match the current + section, which means ``_mesh_version`` must equal the captured + value — mesh-adapt during the interval would have resized the + section and is detected as a v1 refusal here. + + v1.2 will replace the ``_mesh_version`` refusal with a + rebuild-from-payload path: destroy the current DM, rebuild + from ``payload["topology"]``, allocate vectors, write DOFs, + and re-bind MeshVariable / Swarm wrappers. The interface stays + the same; only this method's body changes. + """ + from underworld3.checkpoint.snapshot import SnapshotInvalidatedError + + current_version = int(getattr(self, "_mesh_version", 0)) + captured_version = int(payload["captured_mesh_version"]) + if current_version != captured_version: + raise SnapshotInvalidatedError( + f"mesh {self.name!r}: _mesh_version moved from " + f"{captured_version} to {current_version} since snapshot. 
" + f"mesh.adapt() rebuild on restore is scheduled for v1.2; " + f"v1 refuses rather than corrupt the DOF arrays" + ) + + coords = numpy.asarray(payload["coords"]) + expected_shape = numpy.asarray(self.X.coords).shape + if coords.shape != expected_shape: + raise SnapshotInvalidatedError( + f"mesh {self.name!r}: coordinate shape changed " + f"({coords.shape} vs current {expected_shape}); programming " + f"error since _mesh_version matched" + ) + self._deform_mesh(coords) + + current_vars = {var.clean_name: var for var in self.vars.values()} + for var_clean_name, saved_array in payload["vars"].items(): + var = current_vars.get(var_clean_name) + if var is None: + raise SnapshotInvalidatedError( + f"mesh {self.name!r}: variable {var_clean_name!r} " + f"from snapshot is no longer present" + ) + var._sync_lvec_to_gvec() + current_shape = numpy.asarray(var._gvec.array).shape + if saved_array.shape != current_shape: + raise SnapshotInvalidatedError( + f"mesh {self.name!r}: variable {var_clean_name!r} gvec " + f"shape changed ({saved_array.shape} vs current " + f"{current_shape})" + ) + var._gvec.array[...] = saved_array + iset, subdm = self.dm.createSubDM(var.field_id) + subdm.globalToLocal(var._gvec, var._lvec, addv=False) + iset.destroy() + subdm.destroy() + self._stale_lvec = True + @timing.routine_timer_decorator def write(self, filename: str, index: Optional[int] = None): """ diff --git a/src/underworld3/model.py b/src/underworld3/model.py index 8e1801d2..248aa745 100644 --- a/src/underworld3/model.py +++ b/src/underworld3/model.py @@ -126,6 +126,14 @@ class Model(PintNativeModelMixin, BaseModel): _variables: Dict[str, Any] = PrivateAttr(default_factory=dict) _solvers: Dict[str, Any] = PrivateAttr(default_factory=dict) + # State-bearing objects (DDt instances, parameter-history holders, + # any helper that exposes a .state dataclass per the Snapshottable + # protocol — see src/underworld3/checkpoint/snapshot.py). 
WeakSet so + # the registry does not pin objects past their natural lifetime. + # Snapshot capture and restore iterate this registry; consumers + # other than checkpoint may also walk it. + _state_bearers: Any = PrivateAttr(default_factory=weakref.WeakSet) + def __init__(self, name: Optional[str] = None, **kwargs): """ Initialize a new Model instance. @@ -565,6 +573,47 @@ def get_solver(self, name: str): """Get a solver by name from the model registry""" return self._solvers.get(name) + def _register_state_bearer(self, obj) -> None: + """Register a Snapshottable object with this model. + + Called by helper classes (DDt, parameter-history holders, ...) + in their ``__init__``. ``obj`` should expose a ``.state`` + attribute returning a dataclass with ``_schema_version``. + Membership is via WeakSet, so registration does not extend + ``obj``'s lifetime. + """ + self._state_bearers.add(obj) + + def snapshot(self, *, path: Optional[str] = None): + """Capture a unitary in-memory snapshot of the model's state. + + v1 covers mesh coordinates and mesh-variable DOFs across every + registered mesh. Subsequent PRs extend coverage to swarms and + solver-internal Python state. + + Pass ``path=...`` to write to an HDF5 file once the on-disk + backend lands (v1.1); v1 raises ``NotImplementedError``. + + See ``docs/developer/design/in_memory_checkpoint_design.md`` + for the full design. + """ + from underworld3.checkpoint import snapshot as _snapshot + + return _snapshot(self, path=path) + + def restore(self, snap) -> None: + """Restore the model from a :class:`underworld3.checkpoint.Snapshot`. + + Within-process restore: ``snap`` must have been produced by + :meth:`snapshot` on this same ``Model`` instance. Raises + :class:`underworld3.checkpoint.SnapshotInvalidatedError` if + the mesh has been adapted, or a captured mesh / variable is + no longer registered. 
+ """ + from underworld3.checkpoint import restore as _restore + + return _restore(self, snap) + def define_parameter(self, name: str, ptype=None, **kwargs): """ Define a new parameter with validation rules. diff --git a/src/underworld3/swarm.py b/src/underworld3/swarm.py index ff8dc593..d0408e6e 100644 --- a/src/underworld3/swarm.py +++ b/src/underworld3/swarm.py @@ -2493,6 +2493,15 @@ def __init__(self, mesh, recycle_rate=0, verbose=False, clip_to_mesh=True): # Mesh version tracking for coordinate change detection self._mesh_version = mesh._mesh_version + # Informational counter incremented at every particle-population + # mutation site (populate, migrate, add_particles_*, advection + # remesh). NOT a snapshot-restore invalidation gate — restore + # rebuilds the local population from the snapshot regardless of + # what happened in between. Useful for logging, debugging, and + # any cache that wants to know "did the population change?" + # See docs/developer/design/in_memory_checkpoint_design.md. + self._population_generation = 0 + # Register this swarm with the mesh for coordinate change notifications mesh.register_swarm(self) @@ -3286,6 +3295,9 @@ def populate( offset = swarm_orig_size * i self._remeshed.data[offset::, 0] = i + # Informational: particle population just changed. + self._population_generation += 1 + return @timing.routine_timer_decorator @@ -3315,6 +3327,11 @@ def migrate( if self._migration_disabled: return + # Informational: migration may move or drop particles. Bump + # unconditionally; restore is not gated on this counter so a + # conservative no-op bump is harmless. + self._population_generation += 1 + from time import time if delete_lost_points is None: @@ -3536,6 +3553,10 @@ def add_particles_with_coordinates(self, coordinatesArray) -> int: if hasattr(var, "_canonical_data"): var._canonical_data = None + # Informational: addNPoints + direct dm.migrate path doesn't go + # through Swarm.migrate, so bump explicitly. 
+ self._population_generation += 1 + return npoints @timing.routine_timer_decorator @@ -3604,6 +3625,10 @@ def add_particles_with_global_coordinates( self.dm.finalizeFieldRegister() self.dm.addNPoints(npoints=npoints) + # Informational: population changed even if migrate=False is + # passed (in which case Swarm.migrate's bump wouldn't fire). + self._population_generation += 1 + # Add new points with provided coords # Record the current rank (migration needs to know where we start from !) @@ -4056,6 +4081,163 @@ def vars(self): """ return self._vars + # ----- Unitary snapshot / restore ----- + # + # See ``src/underworld3/checkpoint/snapshot.py`` and + # ``docs/developer/design/in_memory_checkpoint_design.md``. Capture + # records the per-rank particle layout and user-variable arrays. + # Restore rebuilds the local population from the snapshot rather + # than refusing on counter mismatch — restore is precisely for the + # case where particles have moved / migrated / been repopulated. + + def _snapshot_stable_name(self) -> str: + """Per-process stable name. ``instance_number`` comes from uw_object.""" + return f"swarm_{self.instance_number}" + + def snapshot_payload(self) -> dict: + """Return a self-contained dict describing this swarm's state. + + Captured: per-rank particle coordinates (from + ``DMSwarmPIC_coor``) and every user swarm-variable's data + array. PETSc-internal variables (``DMSwarmPIC_coor``, + ``DMSwarm_X0``, ``DMSwarm_remeshed``) are excluded — their + contents either come from the captured coords or are + regenerated on the next solve. 
+ """ + coord_field = self.dm.getField("DMSwarmPIC_coor").reshape( + (-1, self.dim) + ) + coords = np.asarray(coord_field).copy() + self.dm.restoreField("DMSwarmPIC_coor") + + var_arrays: dict = {} + for var in list(self._vars.values()): + if var.name.startswith("DMSwarm"): + continue + var_arrays[var.clean_name] = np.asarray(var.data).copy() + + return { + "name": self._snapshot_stable_name(), + "mesh_name": self.mesh.name, + "population_generation": int(self._population_generation), + "coords": coords, + "vars": var_arrays, + } + + def apply_snapshot_payload(self, payload: dict) -> None: + """Rebuild this swarm's local particle population from a payload. + + Algorithm: + + 1. Drop every current local particle (``dm.removePoint`` from + the end is O(1) per call, O(N) total). + 2. Add the captured-rank's particles back via the raw PETSc + primitives — ``addNPoints`` then writing the coord field + directly. We deliberately bypass + :meth:`add_particles_with_coordinates` because that method + filters via ``points_in_domain`` (slow) and triggers + ``dm.migrate`` (unnecessary — saved coords were already + local at capture time, and the mesh hasn't changed). + 3. Write captured per-variable data back. The local particle + count matches the captured count because we just put the + same particles back in the same order. + + This bumps ``_population_generation`` once (explicitly, after the + rebuild in step 2), which is correct: the population *did* just + change. Downstream consumers that care can compare against the + captured value in ``payload['population_generation']``. + """ + from underworld3.checkpoint.snapshot import SnapshotInvalidatedError + + saved_coords = np.asarray(payload["coords"]) + + # Step 1: clear local population. removePoint() removes the last + # particle, so this is O(N) total. + while self.dm.getLocalSize() > 0: + self.dm.removePoint() + + # Step 2: re-add. Add raw points, write coords + ranks directly.
+ n_saved = int(saved_coords.shape[0]) + if n_saved > 0: + self.dm.finalizeFieldRegister() + self.dm.addNPoints(npoints=n_saved) + + coord_field = self.dm.getField("DMSwarmPIC_coor").reshape( + (-1, self.dim) + ) + if coord_field.shape != saved_coords.shape: + self.dm.restoreField("DMSwarmPIC_coor") + raise SnapshotInvalidatedError( + f"swarm {self._snapshot_stable_name()!r}: after " + f"addNPoints({n_saved}) the coord field has shape " + f"{coord_field.shape}, expected {saved_coords.shape}" + ) + coord_field[...] = saved_coords + self.dm.restoreField("DMSwarmPIC_coor") + + rank_field = self.dm.getField("DMSwarm_rank") + rank_field[...] = uw.mpi.rank + self.dm.restoreField("DMSwarm_rank") + + # Invalidate canonical-data caches — the underlying arrays + # have been reallocated by the addNPoints path. + if hasattr(self._particle_coordinates, "_canonical_data"): + self._particle_coordinates._canonical_data = None + for var in self._vars.values(): + if hasattr(var, "_canonical_data"): + var._canonical_data = None + + # The raw PETSc primitives used above (removePoint loop + + # addNPoints + direct field writes) deliberately bypass + # Swarm.migrate / add_particles_with_coordinates, so they do + # not touch _population_generation. Bump it explicitly here + # for consistency with the other mutation sites — a restore + # IS a population change, downstream consumers should see it. + # (Comment rewritten per Copilot review on #195.) + self._population_generation += 1 + + # Step 3: write captured per-variable data. Per Copilot + # review on #195, also raise on the inverse direction — + # any user swarm variable on the LIVE swarm that wasn't in + # the snapshot would retain stale/uninitialised contents + # after the clear+addNPoints reallocation, which is exactly + # the silent-incoherence failure we want loud rather than + # quiet. The contract is symmetric with the mesh-variable + # restore: same variable set on both sides. 
+ current_user_vars = { + var.clean_name: var + for var in self._vars.values() + if not var.name.startswith("DMSwarm") + } + captured_names = set(payload["vars"].keys()) + live_names = set(current_user_vars.keys()) + extras = live_names - captured_names + if extras: + raise SnapshotInvalidatedError( + f"swarm {self._snapshot_stable_name()!r}: variables " + f"{sorted(extras)!r} exist on the live swarm but were " + f"not in the snapshot. Restore would leave them with " + f"incoherent data after the population rebuild — add " + f"them before the snapshot was taken, or remove them " + f"before restoring." + ) + + for var_clean_name, saved in payload["vars"].items(): + var = current_user_vars.get(var_clean_name) + if var is None: + raise SnapshotInvalidatedError( + f"swarm {self._snapshot_stable_name()!r}: variable " + f"{var_clean_name!r} from snapshot is not present" + ) + current = np.asarray(var.data) + if current.shape != saved.shape: + raise SnapshotInvalidatedError( + f"swarm {self._snapshot_stable_name()!r}: variable " + f"{var_clean_name!r} data shape mismatch — current " + f"{current.shape} vs snapshot {saved.shape}" + ) + current[...] = saved + def _legacy_access(self, *writeable_vars: SwarmVariable): """ This context manager makes the underlying swarm variables data available to @@ -4470,6 +4652,9 @@ def advection( self.dm.addNPoints(num_remeshed_points) + # Informational: remesh just re-injected particles. 
+ self._population_generation += 1 + ## cellid = self.dm.getField("DMSwarm_cellid") coords = self.dm.getField("DMSwarmPIC_coor").reshape((-1, self.dim)) rmsh = self.dm.getField("DMSwarm_remeshed") diff --git a/src/underworld3/systems/ddt.py b/src/underworld3/systems/ddt.py index 39bf5692..bea3d13e 100644 --- a/src/underworld3/systems/ddt.py +++ b/src/underworld3/systems/ddt.py @@ -54,17 +54,115 @@ from sympy import sympify import numpy as np -from typing import Optional, Callable, Union +from dataclasses import dataclass, field +from typing import Any, Optional, Callable, Union import underworld3 as uw from underworld3 import VarType import underworld3.timing as timing from underworld3.utilities._api_tools import uw_object +from underworld3.checkpoint.state import SnapshottableState from petsc4py import PETSc +# ----- Snapshot state dataclasses for DDt flavors ----- +# +# Per the design note's "General serialisation contract" section, each +# DDt class exposes a derived State dataclass via ``.state``. The +# private ``_dt_history`` / ``_history_initialised`` / etc. remain the +# authoritative store; the dataclass is built on read and unpacked on +# write. See ``src/underworld3/checkpoint/state.py``. +# +# PR 3 retrofits the Symbolic class. PR 4 will extend the pattern to +# Eulerian, SemiLagrangian, Lagrangian, and Lagrangian_Swarm — each has +# the same dt_history / history_initialised / n_solves_completed / dt +# core plus a flavor-specific psi_star shape. + + +@dataclass +class _DDtCoreState(SnapshottableState): + """Common evolution-tracking fields shared by every DDt flavor. + + Each concrete flavor extends this with its own psi_star + representation (sympy expressions for Symbolic; mesh-variable + names for Eulerian / SemiLagrangian; swarm-variable names for + Lagrangian / Lagrangian_Swarm). 
The actual variable DOF / particle + data lives in the mesh-variable or swarm-variable path of the + snapshot — these State dataclasses carry only the metadata needed + to re-bind on restore. + """ + + _schema_version: int = 1 + dt_history: list = field(default_factory=list) + history_initialised: bool = False + n_solves_completed: int = 0 + dt: Any = None + + +@dataclass +class DDtSymbolicState(_DDtCoreState): + """Snapshot of a :class:`Symbolic` DDt instance's evolution state. + + ``Symbolic`` is the pure-symbolic flavor — ``psi_star`` history + slots hold sympy expressions (immutable), captured by value. + """ + + psi_star: list = field(default_factory=list) + + +@dataclass +class DDtEulerianState(_DDtCoreState): + """Snapshot of an :class:`Eulerian` DDt instance. + + ``psi_star`` is a list of :class:`MeshVariable` objects; their + DOF arrays travel via the mesh-variable snapshot path. This State + only records the variable names for restore-side verification + that the binding still holds. + """ + + psi_star_var_names: list[str] = field(default_factory=list) + + +@dataclass +class DDtSemiLagrangianState(_DDtCoreState): + """Snapshot of a :class:`SemiLagrangian` DDt instance. + + Like :class:`DDtEulerianState`, plus an optional ``forcing_star`` + variable (when ``with_forcing_history=True``) used by ETD-2 + integration of the Maxwell relaxation operator. + """ + + psi_star_var_names: list[str] = field(default_factory=list) + forcing_star_var_name: Optional[str] = None + with_forcing_history: bool = False + + +@dataclass +class DDtLagrangianState(_DDtCoreState): + """Snapshot of a :class:`Lagrangian` DDt instance. + + ``psi_star`` is a list of :class:`SwarmVariable` objects on this + DDt's internal swarm. Their data travels via the swarm-variable + snapshot path. + """ + + psi_star_var_names: list[str] = field(default_factory=list) + + +@dataclass +class DDtLagrangianSwarmState(_DDtCoreState): + """Snapshot of a :class:`Lagrangian_Swarm` DDt instance. 
+ + Same shape as :class:`DDtLagrangianState`; the difference is + operational (Lagrangian creates its own swarm, Lagrangian_Swarm + uses a user-provided one) rather than state-shaped. + """ + + psi_star_var_names: list[str] = field(default_factory=list) + + def _as_float(value): """Extract a plain float from various numeric types (Pint, UWQuantity, etc.).""" if value is None: @@ -472,8 +570,77 @@ def __init__( _update_am_values(self._am_coeffs, 1, self.theta) _update_exp_values(self._exp_coeffs, None, None) + # Register with the active default model as a Snapshottable + # state-bearer. Safe if no model is active. + try: + import underworld3 as _uw + + _uw.get_default_model()._register_state_bearer(self) + except (ImportError, AttributeError): + # Narrowed per Copilot review on #195: only swallow the + # genuine bootstrap modes (import not yet wired during + # underworld3 init, or older Model without the registry + # method). Anything else propagates rather than silently + # masking a registration bug — exactly the silent-state- + # loss failure mode the design note warns against. + pass + return + # ----- Unitary snapshot / restore ----- + # + # Option (B)-style adapter per the design note: state is a derived + # dataclass that surfaces the mutable evolution-tracking attrs. + # The private ``_dt_history`` / ``_history_initialised`` / etc. + # remain the authoritative store; the State dataclass is built on + # read and unpacked on write. + # + # See ``docs/developer/design/in_memory_checkpoint_design.md`` and + # ``src/underworld3/checkpoint/state.py`` for the contract. 
+ + @property + def state(self) -> "DDtSymbolicState": + """Return a snapshot-of-state dataclass for this DDt instance.""" + return DDtSymbolicState( + dt_history=list(self._dt_history), + history_initialised=bool(self._history_initialised), + n_solves_completed=int(self._n_solves_completed), + dt=self._dt, + psi_star=list(self.psi_star), + ) + + @state.setter + def state(self, s: "DDtSymbolicState") -> None: + """Write a captured state back. Reconciles derived coefficients.""" + if s._schema_version != DDtSymbolicState._schema_version: + raise ValueError( + f"DDtSymbolicState schema version mismatch: snapshot " + f"{s._schema_version} vs current " + f"{DDtSymbolicState._schema_version}" + ) + if len(s.dt_history) != len(self._dt_history): + raise ValueError( + f"dt_history length mismatch ({len(s.dt_history)} vs " + f"{len(self._dt_history)}); order changed since snapshot?" + ) + if len(s.psi_star) != len(self.psi_star): + raise ValueError( + f"psi_star length mismatch ({len(s.psi_star)} vs " + f"{len(self.psi_star)}); order changed since snapshot?" + ) + self._dt_history = list(s.dt_history) + self._history_initialised = bool(s.history_initialised) + self._n_solves_completed = int(s.n_solves_completed) + self._dt = s.dt + self.psi_star = list(s.psi_star) + # Re-derive BDF/AM coefficients so downstream reads see values + # consistent with the restored primary state without waiting + # for the next update_pre_solve. 
+ _update_bdf_values( + self._bdf_coeffs, self.effective_order, self._dt, self._dt_history + ) + _update_am_values(self._am_coeffs, self.effective_order, self.theta) + @property def psi_fn(self): r"""Current symbolic expression :math:`\psi` being tracked.""" @@ -782,8 +949,60 @@ def __init__( _update_am_values(self._am_coeffs, 1, self.theta) _update_exp_values(self._exp_coeffs, None, None) + try: + import underworld3 as _uw + + _uw.get_default_model()._register_state_bearer(self) + except (ImportError, AttributeError): + # Narrowed per Copilot review on #195: only swallow the + # genuine bootstrap modes (import not yet wired during + # underworld3 init, or older Model without the registry + # method). Anything else propagates rather than silently + # masking a registration bug — exactly the silent-state- + # loss failure mode the design note warns against. + pass + return + @property + def state(self) -> "DDtEulerianState": + """Return a snapshot-of-state dataclass for this Eulerian DDt.""" + return DDtEulerianState( + dt_history=list(self._dt_history), + history_initialised=bool(self._history_initialised), + n_solves_completed=int(self._n_solves_completed), + dt=self._dt, + psi_star_var_names=[ps.clean_name for ps in self.psi_star], + ) + + @state.setter + def state(self, s: "DDtEulerianState") -> None: + if s._schema_version != DDtEulerianState._schema_version: + raise ValueError( + f"DDtEulerianState schema version mismatch: snapshot " + f"{s._schema_version} vs current " + f"{DDtEulerianState._schema_version}" + ) + if len(s.dt_history) != len(self._dt_history): + raise ValueError( + f"dt_history length mismatch ({len(s.dt_history)} vs " + f"{len(self._dt_history)}); order changed since snapshot?" 
+ ) + current_names = [ps.clean_name for ps in self.psi_star] + if s.psi_star_var_names and s.psi_star_var_names != current_names: + raise ValueError( + f"psi_star variable names changed since snapshot: " + f"{s.psi_star_var_names} vs {current_names}" + ) + self._dt_history = list(s.dt_history) + self._history_initialised = bool(s.history_initialised) + self._n_solves_completed = int(s.n_solves_completed) + self._dt = s.dt + _update_bdf_values( + self._bdf_coeffs, self.effective_order, self._dt, self._dt_history + ) + _update_am_values(self._am_coeffs, self.effective_order, self.theta) + @property def psi_fn(self): r"""Current symbolic expression :math:`\psi` being tracked.""" @@ -1380,8 +1599,73 @@ def __init__( self.I = uw.maths.Integral(mesh, None) + try: + import underworld3 as _uw + + _uw.get_default_model()._register_state_bearer(self) + except (ImportError, AttributeError): + # Narrowed per Copilot review on #195: only swallow the + # genuine bootstrap modes (import not yet wired during + # underworld3 init, or older Model without the registry + # method). Anything else propagates rather than silently + # masking a registration bug — exactly the silent-state- + # loss failure mode the design note warns against. 
+ pass + return + @property + def state(self) -> "DDtSemiLagrangianState": + return DDtSemiLagrangianState( + dt_history=list(self._dt_history), + history_initialised=bool(self._history_initialised), + n_solves_completed=int(self._n_solves_completed), + dt=self._dt, + psi_star_var_names=[ps.clean_name for ps in self.psi_star], + forcing_star_var_name=( + self.forcing_star.clean_name + if self.forcing_star is not None else None + ), + with_forcing_history=bool(self.with_forcing_history), + ) + + @state.setter + def state(self, s: "DDtSemiLagrangianState") -> None: + if s._schema_version != DDtSemiLagrangianState._schema_version: + raise ValueError( + f"DDtSemiLagrangianState schema version mismatch: snapshot " + f"{s._schema_version} vs current " + f"{DDtSemiLagrangianState._schema_version}" + ) + if len(s.dt_history) != len(self._dt_history): + raise ValueError( + f"dt_history length mismatch ({len(s.dt_history)} vs " + f"{len(self._dt_history)}); order changed since snapshot?" + ) + current_names = [ps.clean_name for ps in self.psi_star] + if s.psi_star_var_names and s.psi_star_var_names != current_names: + raise ValueError( + f"psi_star variable names changed since snapshot: " + f"{s.psi_star_var_names} vs {current_names}" + ) + if s.with_forcing_history != bool(self.with_forcing_history): + raise ValueError( + f"with_forcing_history flag differs between snapshot " + f"({s.with_forcing_history}) and current " + f"({self.with_forcing_history})" + ) + self._dt_history = list(s.dt_history) + self._history_initialised = bool(s.history_initialised) + self._n_solves_completed = int(s.n_solves_completed) + self._dt = s.dt + _update_bdf_values( + self._bdf_coeffs, self.effective_order, self._dt, self._dt_history + ) + # SemiLagrangian's update_pre_solve uses theta=0.5 directly + # (it doesn't take a theta argument in __init__), so the setter + # matches that. 
+ _update_am_values(self._am_coeffs, self.effective_order, 0.5) + @property def psi_fn(self): r"""Current symbolic expression :math:`\psi` being tracked.""" @@ -2262,7 +2546,7 @@ def __init__( super().__init__() # create a new swarm to manage here - dudt_swarm = uw.swarm.UWSwarm(mesh) + dudt_swarm = uw.swarm.Swarm(mesh) self.mesh = mesh self.swarm = dudt_swarm @@ -2301,8 +2585,59 @@ def __init__( dudt_swarm.populate(fill_param) + try: + import underworld3 as _uw + + _uw.get_default_model()._register_state_bearer(self) + except (ImportError, AttributeError): + # Narrowed per Copilot review on #195: only swallow the + # genuine bootstrap modes (import not yet wired during + # underworld3 init, or older Model without the registry + # method). Anything else propagates rather than silently + # masking a registration bug — exactly the silent-state- + # loss failure mode the design note warns against. + pass + return + @property + def state(self) -> "DDtLagrangianState": + return DDtLagrangianState( + dt_history=list(self._dt_history), + history_initialised=bool(self._history_initialised), + n_solves_completed=int(self._n_solves_completed), + dt=self._dt, + psi_star_var_names=[ps.clean_name for ps in self.psi_star], + ) + + @state.setter + def state(self, s: "DDtLagrangianState") -> None: + if s._schema_version != DDtLagrangianState._schema_version: + raise ValueError( + f"DDtLagrangianState schema version mismatch: snapshot " + f"{s._schema_version} vs current " + f"{DDtLagrangianState._schema_version}" + ) + if len(s.dt_history) != len(self._dt_history): + raise ValueError( + f"dt_history length mismatch ({len(s.dt_history)} vs " + f"{len(self._dt_history)}); order changed since snapshot?" 
+ ) + current_names = [ps.clean_name for ps in self.psi_star] + if s.psi_star_var_names and s.psi_star_var_names != current_names: + raise ValueError( + f"psi_star variable names changed since snapshot: " + f"{s.psi_star_var_names} vs {current_names}" + ) + self._dt_history = list(s.dt_history) + self._history_initialised = bool(s.history_initialised) + self._n_solves_completed = int(s.n_solves_completed) + self._dt = s.dt + _update_bdf_values( + self._bdf_coeffs, self.effective_order, self._dt, self._dt_history + ) + _update_am_values(self._am_coeffs, self.effective_order, 0.5) + def _object_viewer(self): from IPython.display import Latex, Markdown, display @@ -2602,8 +2937,59 @@ def __init__( _update_bdf_values(self._bdf_coeffs, 1, None, []) _update_am_values(self._am_coeffs, 1, 0.5) + try: + import underworld3 as _uw + + _uw.get_default_model()._register_state_bearer(self) + except (ImportError, AttributeError): + # Narrowed per Copilot review on #195: only swallow the + # genuine bootstrap modes (import not yet wired during + # underworld3 init, or older Model without the registry + # method). Anything else propagates rather than silently + # masking a registration bug — exactly the silent-state- + # loss failure mode the design note warns against. 
+ pass + return + @property + def state(self) -> "DDtLagrangianSwarmState": + return DDtLagrangianSwarmState( + dt_history=list(self._dt_history), + history_initialised=bool(self._history_initialised), + n_solves_completed=int(self._n_solves_completed), + dt=self._dt, + psi_star_var_names=[ps.clean_name for ps in self.psi_star], + ) + + @state.setter + def state(self, s: "DDtLagrangianSwarmState") -> None: + if s._schema_version != DDtLagrangianSwarmState._schema_version: + raise ValueError( + f"DDtLagrangianSwarmState schema version mismatch: snapshot " + f"{s._schema_version} vs current " + f"{DDtLagrangianSwarmState._schema_version}" + ) + if len(s.dt_history) != len(self._dt_history): + raise ValueError( + f"dt_history length mismatch ({len(s.dt_history)} vs " + f"{len(self._dt_history)}); order changed since snapshot?" + ) + current_names = [ps.clean_name for ps in self.psi_star] + if s.psi_star_var_names and s.psi_star_var_names != current_names: + raise ValueError( + f"psi_star variable names changed since snapshot: " + f"{s.psi_star_var_names} vs {current_names}" + ) + self._dt_history = list(s.dt_history) + self._history_initialised = bool(s.history_initialised) + self._n_solves_completed = int(s.n_solves_completed) + self._dt = s.dt + _update_bdf_values( + self._bdf_coeffs, self.effective_order, self._dt, self._dt_history + ) + _update_am_values(self._am_coeffs, self.effective_order, 0.5) + def _object_viewer(self): from IPython.display import Latex, Markdown, display diff --git a/tests/parallel/mpi_runner.sh b/tests/parallel/mpi_runner.sh index eed4a0c9..3b479dbe 100755 --- a/tests/parallel/mpi_runner.sh +++ b/tests/parallel/mpi_runner.sh @@ -24,3 +24,10 @@ mpirun -np 4 $PYTHON ./ptest_002_projection.py #mpirun -np 1 $PYTHON ./ptest_003_swarm_projection.py #echo "ptest 003 -np 4" #mpirun -np 4 $PYTHON ./ptest_003_swarm_projection.py + +echo "ptest 0007 snapshot in-memory -np 1" +mpirun -np 1 $PYTHON ./ptest_0007_snapshot_inmemory.py +echo "ptest 0007 
snapshot in-memory -np 3 (uneven partition)" +mpirun -np 3 $PYTHON ./ptest_0007_snapshot_inmemory.py +echo "ptest 0007 snapshot in-memory -np 4" +mpirun -np 4 $PYTHON ./ptest_0007_snapshot_inmemory.py diff --git a/tests/parallel/ptest_0007_snapshot_inmemory.py b/tests/parallel/ptest_0007_snapshot_inmemory.py new file mode 100644 index 00000000..d533d046 --- /dev/null +++ b/tests/parallel/ptest_0007_snapshot_inmemory.py @@ -0,0 +1,187 @@ +"""Parallel (MPI) test of the in-memory snapshot toolkit. + +The single open production blocker for the snapshot toolkit was +"works everywhere" — i.e. correct under MPI. The design intent is +that swarm restore is a per-rank *reconstruction* (each rank clears +its local particles and re-adds the per-rank set it captured), not a +redistribution, so the global state is exactly reconstructed +regardless of any intervening cross-rank migration, provided the rank +count is unchanged. This script confirms that. + +Run (4 ranks exercise cross-rank migration properly): + + cd tests/parallel + mpirun -np 4 python ./ptest_0007_snapshot_inmemory.py + +Asserts (all collective, checked on rank 0): + + 1. Restore recovers the exact global particle count. The disruptive + step is deliberately allowed to *lose* particles across ranks + (advect out / clip) — that is exactly the failure stash-and- + restore exists to undo. The guarantee is that restore brings the + global count back to its pre-step value regardless. + 2. Exact reconstruction: gather every particle's (global-id, x, y, + material) across all ranks, sort by global id; the post-restore + sorted table equals the pre-step sorted table bit-for-bit. + Order- and rank-independent — this is the real proof that + per-rank reconstruction yields the correct global state under + cross-rank migration. + 3. Bit-identical continuation across a stash: a control run and a + run that took a disruptive step then restored and continued + produce bit-identical global sorted state and DDt history.
+""" + +import numpy as np +import sympy +from mpi4py import MPI + +import underworld3 as uw + +comm = MPI.COMM_WORLD +rank = uw.mpi.rank +size = uw.mpi.size + + +def build(): + uw.reset_default_model() + model = uw.get_default_model() + # Wide box so a strip-partition genuinely splits particles across + # ranks; rotation field circulates them across the partition. + mesh = uw.meshing.UnstructuredSimplexBox( + minCoords=(0.0, 0.0), maxCoords=(4.0, 1.0), cellSize=1.0 / 6.0 + ) + x_sym, y_sym = mesh.X + # Rotation about the box centre (2.0, 0.5): particles circulate, + # crossing the vertical rank-partition boundaries. + V_fn = sympy.Matrix([[-(y_sym - 0.5), 0.25 * (x_sym - 2.0)]]).T + + T = uw.discretisation.MeshVariable("T", mesh, 1, degree=1) + T.array[:, 0, 0] = 0.0 + + swarm = uw.swarm.Swarm(mesh) + gid = swarm.add_variable("gid", 1, dtype=float) + material = swarm.add_variable("material", 1, dtype=float) + swarm.populate(fill_param=2) + + # Globally-unique, migration-stable particle id. Swarm variables + # travel with their particle through migration and through + # snapshot/restore, so this is a durable identity tag. + local_n = swarm.dm.getLocalSize() + counts = comm.allgather(local_n) + offset = int(np.sum(counts[:rank])) + gid.data[:, 0] = offset + np.arange(local_n, dtype=float) + material.data[:, 0] = swarm._particle_coordinates.data[:, 0] + + ddt = uw.systems.ddt.Symbolic(T.sym, order=2) + return uw, model, mesh, V_fn, T, swarm, gid, material, ddt + + +def step(uw, V_fn, T, swarm, ddt, dt): + ddt.update_pre_solve(dt) + swarm.advection(V_fn, delta_t=dt, step_limit=False) + T.array[:, 0, 0] = T.array[:, 0, 0] + dt + ddt.update_post_solve(dt) + + +def global_sorted_particles(swarm, gid, material): + """Gather (gid, x, y, material) from all ranks, sorted by gid. + + Order- and rank-independent canonical view of the whole swarm. 
+ """ + g = gid.data[:, 0].copy() + coords = swarm._particle_coordinates.data.copy() + m = material.data[:, 0].copy() + local = np.column_stack([g, coords[:, 0], coords[:, 1], m]) + gathered = comm.allgather(local) + full = np.vstack([a for a in gathered if a.size]) if any( + a.size for a in gathered + ) else np.empty((0, 4)) + order = np.argsort(full[:, 0], kind="stable") + return full[order] + + +def main(): + uw, model, mesh, V_fn, T, swarm, gid, material, ddt = build() + + # Warm up: a few steps so particles have genuinely migrated across + # ranks before we snapshot. + for _ in range(3): + step(uw, V_fn, T, swarm, ddt, 0.1) + + pre = global_sorted_particles(swarm, gid, material) + pre_count = comm.allreduce(swarm.dm.getLocalSize(), op=MPI.SUM) + pre_ddt = (list(ddt.state.dt_history), ddt.state.n_solves_completed) + + snap = model.snapshot() + + # --- Property 1 + 2: a migration-inducing step, then restore --- + step(uw, V_fn, T, swarm, ddt, 0.3) # bigger dt -> more migration + mid_count = comm.allreduce(swarm.dm.getLocalSize(), op=MPI.SUM) + model.restore(snap) + post = global_sorted_particles(swarm, gid, material) + post_count = comm.allreduce(swarm.dm.getLocalSize(), op=MPI.SUM) + post_ddt = (list(ddt.state.dt_history), ddt.state.n_solves_completed) + + exact = np.array_equal(pre, post) + ddt_ok = pre_ddt == post_ddt + + # --- Property 3: bit-identical continuation across a stash --- + snap2 = model.snapshot() + for _ in range(4): + step(uw, V_fn, T, swarm, ddt, 0.1) + ctrl = global_sorted_particles(swarm, gid, material) + ctrl_ddt = (list(ddt.state.dt_history), ddt.state.n_solves_completed) + + model.restore(snap2) + step(uw, V_fn, T, swarm, ddt, 0.5) # the regretted step + model.restore(snap2) + for _ in range(4): + step(uw, V_fn, T, swarm, ddt, 0.1) + stash = global_sorted_particles(swarm, gid, material) + stash_ddt = (list(ddt.state.dt_history), ddt.state.n_solves_completed) + + cont_exact = np.array_equal(ctrl, stash) + cont_ddt_ok = ctrl_ddt == 
stash_ddt + + if rank == 0: + lost = pre_count - mid_count + print(f"[ranks={size}] particles total = {pre_count}", flush=True) + print( + f" disruptive step global count: {pre_count} -> {mid_count} " + f"-> {post_count} ({lost} particle(s) lost by the step, " + f"recovered by restore)", + flush=True, + ) + print(f" P1 restore recovers exact count: " + f"{pre_count == post_count}", flush=True) + print(f" P2 exact reconstruction: {exact}", flush=True) + print(f" P2 DDt state restored: {ddt_ok}", flush=True) + print(f" P3 bit-identical continuation: {cont_exact}", flush=True) + print(f" P3 DDt continuation identical: {cont_ddt_ok}", flush=True) + + # The disruptive step is *allowed* to lose particles across + # ranks — that is precisely the failure a stash-and-restore + # exists to undo. The guarantee is that restore brings the + # global count back exactly and every particle back to the + # right place (P2), regardless of what the step did. + assert pre_count == post_count, ( + f"restore did not recover the exact global particle count: " + f"pre={pre_count} post={post_count} (mid={mid_count})" + ) + assert exact, "swarm not exactly reconstructed after restore" + assert ddt_ok, "DDt state not restored" + assert cont_exact, ( + "continuation after stash is not bit-identical to control" + ) + assert cont_ddt_ok, "DDt continuation not bit-identical" + if lost == 0: + print( + " (note: this run's disruptive step happened not to " + "lose particles; the recovery guarantee still holds)", + flush=True, + ) + print(f"[ranks={size}] PASS", flush=True) + + +if __name__ == "__main__": + main() diff --git a/tests/test_0007_snapshot_inmemory.py b/tests/test_0007_snapshot_inmemory.py new file mode 100644 index 00000000..e7e5bb0d --- /dev/null +++ b/tests/test_0007_snapshot_inmemory.py @@ -0,0 +1,771 @@ +import pytest + +pytestmark = [pytest.mark.level_1, pytest.mark.tier_a] + +import numpy as np + + +def _fresh_model_and_mesh(): + import underworld3 as uw + + 
uw.reset_default_model() + model = uw.get_default_model() + mesh = uw.meshing.UnstructuredSimplexBox( + minCoords=(0.0, 0.0), maxCoords=(1.0, 1.0), cellSize=1.0 / 8.0 + ) + return uw, model, mesh + + +def test_meshvariable_in_memory_roundtrip(): + """Snapshot, scribble, restore: the MV global vector is recovered exactly.""" + uw, model, mesh = _fresh_model_and_mesh() + T = uw.discretisation.MeshVariable("T", mesh, 1, degree=2) + + T.array[:, 0, 0] = T.coords[:, 0] + 2.0 * T.coords[:, 1] + pre_array = np.asarray(T.array[...]).copy() + + snap = model.snapshot() + + T.array[...] = -42.0 + assert not np.allclose(np.asarray(T.array[...]), pre_array), "scribble didn't take" + + model.restore(snap) + + assert np.allclose(np.asarray(T.array[...]), pre_array, atol=0.0, rtol=0.0), ( + "MeshVariable.array is not bit-equivalent after restore" + ) + + +def test_multiple_meshvariables_roundtrip(): + """All MVs on a mesh are captured and restored, not just the first.""" + uw, model, mesh = _fresh_model_and_mesh() + T = uw.discretisation.MeshVariable("T", mesh, 1, degree=2) + V = uw.discretisation.MeshVariable("V", mesh, 2, degree=2) + + T.array[:, 0, 0] = 1.5 * T.coords[:, 0] + V.array[:, 0, 0] = 3.0 * V.coords[:, 0] + V.array[:, 0, 1] = 7.0 * V.coords[:, 1] + + T_pre = np.asarray(T.array[...]).copy() + V_pre = np.asarray(V.array[...]).copy() + + snap = model.snapshot() + + T.array[...] = 0.0 + V.array[...] = 0.0 + + model.restore(snap) + + assert np.allclose(np.asarray(T.array[...]), T_pre) + assert np.allclose(np.asarray(V.array[...]), V_pre) + + +def test_snapshot_is_independent_of_subsequent_writes(): + """Captured array is a copy; writes to the live MV don't leak into the snapshot.""" + uw, model, mesh = _fresh_model_and_mesh() + T = uw.discretisation.MeshVariable("T", mesh, 1, degree=2) + T.array[:, 0, 0] = 5.0 + + snap = model.snapshot() + T.array[...] = -1.0 + + # The backend still holds the captured value, not the post-write value. 
+ keys = snap.backend.list_vectors() + var_key = next(k for k in keys if "var:T" in k) + captured = snap.backend.load_vector(var_key) + assert np.allclose(captured, 5.0), ( + "in-memory backend did not isolate the captured array from later writes" + ) + + +def test_mesh_version_invalidates_restore(): + """A bumped _mesh_version makes restore refuse rather than silently corrupt.""" + import underworld3 as uw + from underworld3.checkpoint import SnapshotInvalidatedError + + uw_, model, mesh = _fresh_model_and_mesh() + T = uw.discretisation.MeshVariable("T", mesh, 1, degree=2) + T.array[:, 0, 0] = 1.0 + + snap = model.snapshot() + + # Simulate a mesh-mutation event (e.g. adapt(), or any deformation + # routed through the high-level callback that bumps _mesh_version). + mesh._mesh_version += 1 + + with pytest.raises(SnapshotInvalidatedError, match="_mesh_version"): + model.restore(snap) + + +def test_restore_rejects_non_snapshot(): + """A bare dict / array is not a Snapshot; restore raises TypeError.""" + uw, model, mesh = _fresh_model_and_mesh() + _ = uw.discretisation.MeshVariable("T", mesh, 1, degree=2) + + with pytest.raises(TypeError): + model.restore({"not": "a snapshot"}) + + +def test_snapshot_path_is_v1_1_scope(): + """Passing path= raises NotImplementedError until the on-disk backend lands.""" + uw, model, mesh = _fresh_model_and_mesh() + _ = uw.discretisation.MeshVariable("T", mesh, 1, degree=2) + + with pytest.raises(NotImplementedError): + model.snapshot(path="/tmp/should_not_be_written.h5") + + +# ----- Swarm coverage (rebuild-on-restore semantics) ----- + + +def _fresh_model_mesh_and_swarm(with_material=True): + """Fresh model + mesh + populated swarm. 
svar must be added pre-populate.""" + import underworld3 as uw + + uw.reset_default_model() + model = uw.get_default_model() + mesh = uw.meshing.UnstructuredSimplexBox( + minCoords=(0.0, 0.0), maxCoords=(1.0, 1.0), cellSize=1.0 / 4.0 + ) + swarm = uw.swarm.Swarm(mesh) + material = None + if with_material: + material = swarm.add_variable("material", 1, dtype=float) + swarm.populate(fill_param=2) + return uw, model, mesh, swarm, material + + +def _swarm_coords(swarm): + """Return a copy of the current per-rank particle coords.""" + field = swarm.dm.getField("DMSwarmPIC_coor").reshape((-1, swarm.dim)) + out = np.asarray(field).copy() + swarm.dm.restoreField("DMSwarmPIC_coor") + return out + + +def test_swarm_no_change_roundtrip(): + """Trivial case: snapshot, scribble, restore — both coords and svar recovered.""" + uw, model, mesh, swarm, material = _fresh_model_mesh_and_swarm() + material.data[:, 0] = 0.5 * _swarm_coords(swarm)[:, 0] + coords_pre = _swarm_coords(swarm) + material_pre = np.asarray(material.data).copy() + + snap = model.snapshot() + + coord_field = swarm.dm.getField("DMSwarmPIC_coor").reshape((-1, swarm.dim)) + coord_field[...] = -99.0 + swarm.dm.restoreField("DMSwarmPIC_coor") + material.data[...] = -99.0 + + model.restore(snap) + + assert np.allclose(_swarm_coords(swarm), coords_pre) + assert np.allclose(np.asarray(material.data), material_pre) + + +def test_swarm_restore_after_migrate(): + """Migrate between snapshot and restore: restore puts the swarm back. This + is the case my earlier counter-as-gate design wrongly refused.""" + uw, model, mesh, swarm, material = _fresh_model_mesh_and_swarm() + material.data[:, 0] = 1.0 + coords_pre = _swarm_coords(swarm) + material_pre = np.asarray(material.data).copy() + pop_gen_pre = swarm._population_generation + + snap = model.snapshot() + + # Mutate: migrate() will bump the counter regardless of whether + # particles actually moved. Restore must succeed anyway. 
+ swarm.migrate(remove_sent_points=True) + assert swarm._population_generation > pop_gen_pre, "migrate didn't bump counter" + + model.restore(snap) + + assert np.allclose(_swarm_coords(swarm), coords_pre), ( + "restore did not recover particle coords across a migrate event" + ) + assert np.allclose(np.asarray(material.data), material_pre), ( + "restore did not recover svar data across a migrate event" + ) + + +def test_swarm_restore_after_add_particles(): + """Particles added between snapshot and restore: restore *removes* them.""" + uw, model, mesh, swarm, material = _fresh_model_mesh_and_swarm() + material.data[:, 0] = 2.0 + coords_pre = _swarm_coords(swarm) + material_pre = np.asarray(material.data).copy() + npre = swarm.dm.getLocalSize() + + snap = model.snapshot() + + swarm.add_particles_with_coordinates( + np.array([[0.5, 0.5], [0.25, 0.75]]) + ) + assert swarm.dm.getLocalSize() != npre, "add_particles didn't grow swarm" + + model.restore(snap) + + assert swarm.dm.getLocalSize() == npre, ( + "restore did not roll back to the captured particle count" + ) + assert np.allclose(_swarm_coords(swarm), coords_pre) + assert np.allclose(np.asarray(material.data), material_pre) + + +def test_swarm_population_generation_is_informational_not_a_gate(): + """The counter ticks up across mutations but does NOT block restore.""" + uw, model, mesh, swarm, _ = _fresh_model_mesh_and_swarm() + gen_at_capture = swarm._population_generation + + snap = model.snapshot() + + swarm.migrate(remove_sent_points=True) + swarm.add_particles_with_coordinates(np.array([[0.5, 0.5]])) + gen_during = swarm._population_generation + assert gen_during > gen_at_capture + + # Restore is expected to *succeed*, not raise. + model.restore(snap) + + # And the counter has moved on from where it was at capture, + # because restore itself counts as a population change. 
+ assert swarm._population_generation > gen_at_capture + + +def test_swarm_internal_variables_are_not_captured(): + """Internal DMSwarm_* variables stay out of the snapshot key list.""" + uw, model, mesh, swarm, material = _fresh_model_mesh_and_swarm() + + snap = model.snapshot() + keys = snap.backend.list_vectors() + swarm_name = swarm._snapshot_stable_name() + svar_keys = [k for k in keys if k.startswith(f"swarm:{swarm_name}:var:")] + captured_names = {k.split(":var:")[1].split(":data")[0] for k in svar_keys} + + assert "material" in captured_names + assert not any(n.startswith("DMSwarm") for n in captured_names) + + +# ----- State-as-dataclass contract: Symbolic DDt ----- + + +def _fresh_model_mesh_and_symbolic_ddt(order=2): + import underworld3 as uw + + uw.reset_default_model() + model = uw.get_default_model() + mesh = uw.meshing.UnstructuredSimplexBox( + minCoords=(0.0, 0.0), maxCoords=(1.0, 1.0), cellSize=1.0 / 4.0 + ) + T = uw.discretisation.MeshVariable("T", mesh, 1, degree=1) + ddt = uw.systems.ddt.Symbolic(T.sym, order=order) + return uw, model, mesh, T, ddt + + +def test_symbolic_ddt_registers_with_model(): + """A fresh DDt auto-registers in Model._state_bearers.""" + uw, model, mesh, T, ddt = _fresh_model_mesh_and_symbolic_ddt() + assert ddt in model._state_bearers + + +def test_symbolic_ddt_state_is_a_snapshottable_dataclass(): + """``.state`` returns a SnapshottableState (DDtSymbolicState) with the + expected schema version.""" + from underworld3.checkpoint import SnapshottableState + + uw, model, mesh, T, ddt = _fresh_model_mesh_and_symbolic_ddt(order=2) + state = ddt.state + assert isinstance(state, SnapshottableState) + assert state._schema_version == 1 + # Fresh DDt: order-sized dt_history, not initialised, zero solves. 
+ assert state.dt_history == [None, None] + assert state.history_initialised is False + assert state.n_solves_completed == 0 + + +def test_symbolic_ddt_roundtrip_recovers_state(): + """Snapshot mid-trajectory, advance, restore, state equals captured.""" + uw, model, mesh, T, ddt = _fresh_model_mesh_and_symbolic_ddt(order=2) + + # Advance two solves so dt_history fills. + ddt.update_pre_solve(dt=0.1) + ddt.update_post_solve(dt=0.1) + ddt.update_pre_solve(dt=0.2) + ddt.update_post_solve(dt=0.2) + state_pre = ddt.state + # Sanity: history is populated. + assert state_pre.history_initialised is True + assert state_pre.n_solves_completed == 2 + assert state_pre.dt_history == [0.2, 0.1] + + snap = model.snapshot() + + # Mutate: take another solve, dt_history changes. + ddt.update_pre_solve(dt=0.5) + ddt.update_post_solve(dt=0.5) + assert ddt.state.dt_history == [0.5, 0.2] + + model.restore(snap) + + # Primary state is back to captured. + state_post = ddt.state + assert state_post.dt_history == state_pre.dt_history + assert state_post.history_initialised == state_pre.history_initialised + assert state_post.n_solves_completed == state_pre.n_solves_completed + assert state_post.dt == state_pre.dt + + +def test_symbolic_ddt_restore_rejects_wrong_schema_version(): + """Hand-built state with wrong _schema_version is refused on apply.""" + uw, model, mesh, T, ddt = _fresh_model_mesh_and_symbolic_ddt(order=2) + bad_state = ddt.state + bad_state._schema_version = 999 + + with pytest.raises(ValueError, match="schema version"): + ddt.state = bad_state + + +def test_symbolic_ddt_restore_rejects_order_mismatch(): + """Restoring a state captured at a different order raises (programming- + error guard; in practice this shouldn't happen within a single run).""" + uw, model, mesh, T, ddt = _fresh_model_mesh_and_symbolic_ddt(order=2) + bad_state = ddt.state + bad_state.dt_history = [0.1, 0.2, 0.3] # length 3 != order 2 + + with pytest.raises(ValueError, match="dt_history length 
mismatch"): + ddt.state = bad_state + + +def test_symbolic_ddt_snapshot_is_deep_copy(): + """Mutating the live DDt after snapshot doesn't leak into the + captured state-bearer payload.""" + uw, model, mesh, T, ddt = _fresh_model_mesh_and_symbolic_ddt(order=2) + ddt.update_pre_solve(dt=0.1) + ddt.update_post_solve(dt=0.1) + + snap = model.snapshot() + captured_state = snap.state_bearers[0][1] # (key, state) + captured_dt_history = list(captured_state.dt_history) + + # Scribble the live DDt's internal state — must not leak into snapshot. + ddt._dt_history[0] = -999.0 + ddt._n_solves_completed = 42 + + assert captured_state.dt_history == captured_dt_history + assert captured_state.n_solves_completed != 42 + + +# ----- State-as-dataclass: other DDt flavors ----- +# +# Construction-side smoke tests + roundtrip. We exercise the .state / +# .state.setter mechanics directly rather than running full solves; +# the BDF/AM coefficient re-derivation happens in the setter, so a +# manual primary-state mutation is enough to validate the retrofit. + + +def test_eulerian_ddt_roundtrip(): + import underworld3 as uw + from underworld3.systems.ddt import DDtEulerianState + + uw.reset_default_model() + model = uw.get_default_model() + mesh = uw.meshing.UnstructuredSimplexBox( + minCoords=(0.0, 0.0), maxCoords=(1.0, 1.0), cellSize=1.0 / 4.0 + ) + T = uw.discretisation.MeshVariable("T", mesh, 1, degree=1) + ddt = uw.systems.ddt.Eulerian( + mesh, T.sym, uw.VarType.SCALAR, degree=1, continuous=True, order=2 + ) + assert ddt in model._state_bearers + assert isinstance(ddt.state, DDtEulerianState) + + # Manually advance state (avoid running real projections). 
+ ddt._dt_history = [0.2, 0.1] + ddt._history_initialised = True + ddt._n_solves_completed = 2 + ddt._dt = 0.2 + state_pre = ddt.state + + snap = model.snapshot() + + ddt._dt_history = [0.99, 0.99] + ddt._history_initialised = False + ddt._n_solves_completed = 0 + ddt._dt = None + + model.restore(snap) + + assert ddt.state.dt_history == state_pre.dt_history + assert ddt.state.history_initialised == state_pre.history_initialised + assert ddt.state.n_solves_completed == state_pre.n_solves_completed + assert ddt.state.dt == state_pre.dt + # psi_star names are stable bindings — must match. + assert ddt.state.psi_star_var_names == state_pre.psi_star_var_names + + +def test_semilagrangian_ddt_roundtrip(): + import underworld3 as uw + from underworld3.systems.ddt import DDtSemiLagrangianState + + uw.reset_default_model() + model = uw.get_default_model() + mesh = uw.meshing.UnstructuredSimplexBox( + minCoords=(0.0, 0.0), maxCoords=(1.0, 1.0), cellSize=1.0 / 4.0 + ) + T = uw.discretisation.MeshVariable("T", mesh, 1, degree=1) + V = uw.discretisation.MeshVariable("V", mesh, 2, degree=2) + ddt = uw.systems.ddt.SemiLagrangian( + mesh, T.sym, V.sym, uw.VarType.SCALAR, degree=1, continuous=True, order=2 + ) + assert ddt in model._state_bearers + state = ddt.state + assert isinstance(state, DDtSemiLagrangianState) + assert state.with_forcing_history is False + assert state.forcing_star_var_name is None + + ddt._dt_history = [0.3, 0.1] + ddt._history_initialised = True + ddt._n_solves_completed = 2 + ddt._dt = 0.3 + state_pre = ddt.state + + snap = model.snapshot() + ddt._dt_history = [None, None] + ddt._history_initialised = False + ddt._n_solves_completed = 0 + model.restore(snap) + + assert ddt.state.dt_history == state_pre.dt_history + assert ddt.state.history_initialised is True + assert ddt.state.n_solves_completed == 2 + + +def test_lagrangian_ddt_roundtrip(): + """Lagrangian creates its own internal swarm; the fix in this PR + restored uw.swarm.Swarm in __init__ (was a 
typo'd UWSwarm).""" + import underworld3 as uw + from underworld3.systems.ddt import DDtLagrangianState + + uw.reset_default_model() + model = uw.get_default_model() + mesh = uw.meshing.UnstructuredSimplexBox( + minCoords=(0.0, 0.0), maxCoords=(1.0, 1.0), cellSize=1.0 / 4.0 + ) + T = uw.discretisation.MeshVariable("T", mesh, 1, degree=1) + V = uw.discretisation.MeshVariable("V", mesh, 2, degree=2) + ddt = uw.systems.ddt.Lagrangian( + mesh=mesh, psi_fn=T.sym, V_fn=V.sym, + vtype=uw.VarType.SCALAR, degree=1, continuous=True, order=2, + ) + assert ddt in model._state_bearers + assert isinstance(ddt.state, DDtLagrangianState) + + ddt._dt_history = [0.2, 0.1] + ddt._history_initialised = True + ddt._n_solves_completed = 2 + ddt._dt = 0.2 + state_pre = ddt.state + + snap = model.snapshot() + ddt._dt_history = [None, None] + ddt._history_initialised = False + ddt._n_solves_completed = 0 + model.restore(snap) + + assert ddt.state.dt_history == state_pre.dt_history + assert ddt.state.history_initialised is True + assert ddt.state.n_solves_completed == 2 + assert ddt.state.psi_star_var_names == state_pre.psi_star_var_names + + +# ----- End-to-end back-stepping demonstration ----- +# +# Everything above this comment is unit-style: build a thing, snapshot, +# scribble, restore, check equality. This block exercises the toolkit's +# actual reason for existing: a *real* time-stepping use case where the +# consumer takes a step, detects it was bad, snaps back, and retries +# with smaller Δt. The pattern is canonical adaptive-Δt CFL control; +# the snapshot mechanism is the thing that makes "snap back" possible +# without manually unwinding mesh / swarm / DDt state. + + +def test_backstepping_cfl_recovery_end_to_end(): + """Canonical adaptive-Δt back-step demonstration. + + Set up a swarm advecting in a known velocity field, with a + material variable carried along and a Symbolic DDt accumulating + BDF history. 
Take one too-large Δt step → CFL violation + (max-particle-displacement exceeds the mesh cell radius). Detect + it. Restore the snapshot. Retry with a smaller Δt → CFL satisfied, + state evolves cleanly. The full triple of state (swarm positions, + material variable, DDt history) is recovered on restore. + """ + import underworld3 as uw + import sympy + import numpy as np + + uw.reset_default_model() + model = uw.get_default_model() + mesh = uw.meshing.UnstructuredSimplexBox( + minCoords=(0.0, 0.0), maxCoords=(1.0, 1.0), cellSize=1.0 / 8.0 + ) + + # Outward-radial velocity from the box centre. |V| ranges from 0 + # at the centre to ~0.71 at the corners — pick Δt to give a + # genuine CFL violation rather than tweak parameters to fit. + x, y = mesh.X + V_fn = sympy.Matrix([[x - 0.5, y - 0.5]]).T + + swarm = uw.swarm.Swarm(mesh) + material = swarm.add_variable("material", 1, dtype=float) + swarm.populate(fill_param=2) + coords_initial = swarm._particle_coordinates.data.copy() + material.data[:, 0] = coords_initial[:, 0] # carry x as marker + material_initial = np.asarray(material.data).copy() + + # A separate DDt manages BDF history for a scalar field on the + # mesh. Advance it manually past startup so its captured state is + # non-trivial. + T = uw.discretisation.MeshVariable("T", mesh, 1, degree=1) + ddt = uw.systems.ddt.Symbolic(T.sym, order=2) + ddt._dt_history = [0.05, 0.05] + ddt._history_initialised = True + ddt._n_solves_completed = 2 + ddt._dt = 0.05 + ddt_state_initial = ddt.state + + # The user's CFL threshold: a particle moving more than one cell + # radius in a single step has crossed an element. min_radius is + # the standard UW3 cell-size proxy. + cfl_threshold = mesh.get_min_radius() + + # Take the snapshot *before* the speculative step. Everything that + # will be touched gets captured. + snap = model.snapshot() + + # Speculative step at the candidate Δt. Bigger than the user + # thinks is safe — they'll check after and back-step if it isn't. 
+ candidate_dt = 0.5 + swarm.advection(V_fn, delta_t=candidate_dt, step_limit=False) + + # CFL check: max displacement among local particles. + coords_after_bad = swarm._particle_coordinates.data + max_disp_bad = np.max( + np.linalg.norm(coords_after_bad - coords_initial, axis=1) + ) + assert max_disp_bad > cfl_threshold, ( + f"speculative step at dt={candidate_dt} should violate CFL " + f"(max_disp={max_disp_bad:.4f} vs threshold {cfl_threshold:.4f})" + ) + + # Back-step. Everything captured is brought back to the snapshot + # point — swarm positions, the material variable carried with the + # swarm, and the DDt's BDF history. + model.restore(snap) + + assert np.allclose(swarm._particle_coordinates.data, coords_initial), ( + "particle positions did not roll back after restore" + ) + assert np.allclose(np.asarray(material.data), material_initial), ( + "swarm-variable data did not roll back after restore" + ) + assert ddt.state.dt_history == ddt_state_initial.dt_history, ( + "DDt history did not roll back after restore" + ) + assert ddt.state.n_solves_completed == ddt_state_initial.n_solves_completed + + # Retry with a smaller Δt. CFL now satisfied. + retry_dt = candidate_dt / 10.0 + swarm.advection(V_fn, delta_t=retry_dt, step_limit=False) + + coords_after_good = swarm._particle_coordinates.data + max_disp_good = np.max( + np.linalg.norm(coords_after_good - coords_initial, axis=1) + ) + assert max_disp_good < cfl_threshold, ( + f"retry at dt={retry_dt} should satisfy CFL " + f"(max_disp={max_disp_good:.4f} vs threshold {cfl_threshold:.4f})" + ) + # Sanity: smaller dt produced strictly smaller displacement. + assert max_disp_good < max_disp_bad / 5.0 + + +def test_lagrangian_swarm_ddt_registers_and_state_type(): + """Lagrangian_Swarm must be constructed before swarm.populate; the + retrofit registers it and exposes a typed state. 
Roundtrip is not + exercised here because advection requires a velocity-field setup + beyond the scope of a core unit test.""" + import underworld3 as uw + from underworld3.systems.ddt import DDtLagrangianSwarmState + + uw.reset_default_model() + model = uw.get_default_model() + mesh = uw.meshing.UnstructuredSimplexBox( + minCoords=(0.0, 0.0), maxCoords=(1.0, 1.0), cellSize=1.0 / 4.0 + ) + T = uw.discretisation.MeshVariable("T", mesh, 1, degree=1) + swarm = uw.swarm.Swarm(mesh) + ddt = uw.systems.ddt.Lagrangian_Swarm( + swarm=swarm, + psi_fn=T.sym, + vtype=uw.VarType.SCALAR, + degree=1, + continuous=True, + order=2, + ) + swarm.populate(fill_param=2) + + assert ddt in model._state_bearers + assert isinstance(ddt.state, DDtLagrangianSwarmState) + assert len(ddt.state.dt_history) == 2 + assert ddt.state.psi_star_var_names # non-empty + + +# ----- Bit-identical continuation (the core production guarantee) ----- +# +# Everything above proves *state equality after restore*. That is +# necessary but not the actual guarantee a backtracking consumer +# relies on. The guarantee is: after restore, *continuing the +# simulation* reproduces the trajectory of a run that never took the +# discarded step. These two tests assert that, bit-for-bit +# (np.array_equal — no tolerance), with a swarm + mesh variable + +# Symbolic DDt all live so the mesh -> swarm -> state-bearer restore +# ordering is exercised together. + + +def _build_continuation_fixture(): + """Mesh + swarm(+material) + a driven mesh variable + Symbolic DDt. + + Returns everything needed to run a deterministic step loop. 
+ """ + import underworld3 as uw + import sympy + + uw.reset_default_model() + model = uw.get_default_model() + mesh = uw.meshing.UnstructuredSimplexBox( + minCoords=(0.0, 0.0), maxCoords=(1.0, 1.0), cellSize=1.0 / 6.0 + ) + x_sym, y_sym = mesh.X + V_fn = sympy.Matrix([[x_sym - 0.5, y_sym - 0.5]]).T + + T = uw.discretisation.MeshVariable("T", mesh, 1, degree=1) + T.array[:, 0, 0] = 0.0 + + swarm = uw.swarm.Swarm(mesh) + material = swarm.add_variable("material", 1, dtype=float) + swarm.populate(fill_param=2) + material.data[:, 0] = np.linalg.norm( + swarm._particle_coordinates.data - 0.5, axis=1 + ) + + ddt = uw.systems.ddt.Symbolic(T.sym, order=2) + + return uw, model, mesh, V_fn, T, swarm, material, ddt + + +def _step(uw, V_fn, T, swarm, ddt, dt): + """One deterministic step: advect swarm, evolve T by a fixed rule, + advance the DDt history. No solver, no randomness.""" + ddt.update_pre_solve(dt) + swarm.advection(V_fn, delta_t=dt, step_limit=False) + # Deterministic, history-free field update so T carries evolving + # state through the mesh-variable snapshot path. 
+ T.array[:, 0, 0] = T.array[:, 0, 0] + dt + ddt.update_post_solve(dt) + + +def _capture_full_state(T, swarm, material, ddt): + """Everything that must match for bit-identical continuation.""" + return { + "T": np.asarray(T.array[...]).copy(), + "coords": swarm._particle_coordinates.data.copy(), + "material": np.asarray(material.data).copy(), + "dt_history": list(ddt.state.dt_history), + "n_solves": ddt.state.n_solves_completed, + "ddt_dt": ddt.state.dt, + } + + +def _assert_bit_identical(a, b, label): + assert np.array_equal(a["T"], b["T"]), f"{label}: T differs" + assert np.array_equal(a["coords"], b["coords"]), ( + f"{label}: swarm coords differ" + ) + assert np.array_equal(a["material"], b["material"]), ( + f"{label}: swarm material differs" + ) + assert a["dt_history"] == b["dt_history"], ( + f"{label}: DDt dt_history differs ({a['dt_history']} vs " + f"{b['dt_history']})" + ) + assert a["n_solves"] == b["n_solves"], f"{label}: DDt n_solves differs" + assert a["ddt_dt"] == b["ddt_dt"], f"{label}: DDt dt differs" + + +def test_continuation_deterministic_after_restore(): + """snapshot S -> K steps -> A; restore(S) -> K steps -> B. + A and B must be bit-identical. Proves restore leaves no residual + state that perturbs subsequent evolution.""" + uw, model, mesh, V_fn, T, swarm, material, ddt = ( + _build_continuation_fixture() + ) + + # Advance to a non-trivial state before snapshotting (fill DDt + # history, move particles off their lattice). + for _ in range(3): + _step(uw, V_fn, T, swarm, ddt, 0.05) + + snap = model.snapshot() + + # Branch A: K steps straight from S. + for _ in range(5): + _step(uw, V_fn, T, swarm, ddt, 0.05) + state_A = _capture_full_state(T, swarm, material, ddt) + + # Branch B: restore S, then the identical K steps. 
+ model.restore(snap) + for _ in range(5): + _step(uw, V_fn, T, swarm, ddt, 0.05) + state_B = _capture_full_state(T, swarm, material, ddt) + + _assert_bit_identical(state_A, state_B, "deterministic-continuation") + + +def test_continuation_bit_identical_across_stash_and_recover(): + """The real 'git stash for steps' guarantee: + + control: S -> K good steps -> ctrl + stash: S -> bad disruptive step -> restore(S) + -> same K good steps -> stash + + ctrl and stash must be bit-identical: the discarded step must + leave no trace whatsoever after restore + continuation.""" + uw, model, mesh, V_fn, T, swarm, material, ddt = ( + _build_continuation_fixture() + ) + + for _ in range(3): + _step(uw, V_fn, T, swarm, ddt, 0.05) + + snap = model.snapshot() + + # Control: K good steps from S. + for _ in range(5): + _step(uw, V_fn, T, swarm, ddt, 0.05) + ctrl = _capture_full_state(T, swarm, material, ddt) + + # Stash scenario: back to S, take a deliberately disruptive step + # (10x Δt — large advection, big T jump, DDt history shift), then + # discard it via restore and run the intended K good steps. + model.restore(snap) + _step(uw, V_fn, T, swarm, ddt, 0.5) # the regretted step + model.restore(snap) + for _ in range(5): + _step(uw, V_fn, T, swarm, ddt, 0.05) + stash = _capture_full_state(T, swarm, material, ddt) + + _assert_bit_identical(ctrl, stash, "stash-and-recover") + + diff --git a/tests/test_0008_snapshot_realsolver.py b/tests/test_0008_snapshot_realsolver.py new file mode 100644 index 00000000..33afe4a3 --- /dev/null +++ b/tests/test_0008_snapshot_realsolver.py @@ -0,0 +1,195 @@ +"""Real-solver confidence test for the snapshot toolkit. + +Every other snapshot test drives state by hand. This one runs an +actual PETSc solver — AdvDiffusion, which carries an internal +SemiLagrangian DDt (auxiliary projection SNES + nodal trace-back +swarm) — through the stash-and-recover loop. 
+ +Findings this test pins down (all verified, see commit message): + + * The AdvDiffusion solve is bit-deterministic: two independent + identical runs with no snapshot are bit-for-bit equal. + + * restore() recovers the primary solution field T bit-exactly. + + * THE CORE GUARANTEE — a discarded ("regretted") step leaves zero + trace, bit-for-bit, even through real solves: + restore -> regretted solve -> restore -> K solves + is np.array_equal to + restore -> K solves + (B == C, max|d| = 0.0). + + * The only residual is restore's reproducibility *floor* against a + never-snapshotted control (~7e-7 here): restore resyncs fields + through gvec->lvec rather than reproducing the solver-produced + lvec exactly, and the implicit diffusion operator amplifies that + to solver-tolerance level over steps. This is NOT contamination + from the discarded step (proven by B == C above); it is the cost + of round-tripping through the snapshot representation, within + solver tolerance, and consistent with the design intent that + auxiliary solver state is intentionally not captured. + +So the honest production statement: discarding a bad step is +bit-exact even through real solvers; recovering to a never-stashed +control is within solver tolerance. +""" + +import numpy as np +import sympy as sp +import pytest + +pytestmark = [pytest.mark.level_2, pytest.mark.tier_a] + +# Restore-vs-pristine reproducibility floor for this setup. +# The regretted-step guarantee is asserted bit-exact (np.array_equal); +# only the never-stashed-control comparison uses this tolerance. +# +# The measured floor on this setup is ~7e-7 (see commit 3efc31b); +# the assertion threshold is set ~14× looser at 1e-5 deliberately, +# as headroom for variability across PETSc versions / BLAS libs / +# MPI ranks on CI. Do not tighten without measuring on the target +# environments first. (Comment added per Copilot review on #195.) 
+_RESTORE_FLOOR_ATOL = 1.0e-5 + + +@pytest.fixture(autouse=True) +def _reset(): + import underworld3 as uw + + uw.reset_default_model() + uw.use_strict_units(False) + uw.use_nondimensional_scaling(False) + yield + uw.reset_default_model() + uw.use_strict_units(False) + uw.use_nondimensional_scaling(False) + + +def _build(): + import underworld3 as uw + + model = uw.get_default_model() + mesh = uw.meshing.StructuredQuadBox( + elementRes=(16, 16), + minCoords=(0.0, 0.0), + maxCoords=(1.0, 1.0), + qdegree=3, + ) + v = uw.discretisation.MeshVariable("U", mesh, mesh.dim, degree=1) + T = uw.discretisation.MeshVariable("T", mesh, 1, degree=2) + + adv_diff = uw.systems.AdvDiffusion(mesh, u_Field=T, V_fn=v) + adv_diff.constitutive_model = uw.constitutive_models.DiffusionModel + adv_diff.constitutive_model.Parameters.diffusivity = 1.0 + adv_diff.add_dirichlet_bc(0.0, "Left") + adv_diff.add_dirichlet_bc(0.0, "Right") + v.array[:, 0, 0] = 0.05 + + x, y = mesh.X + T.array = uw.function.evaluate( + sp.sin(sp.pi * x) * sp.sin(sp.pi * y), T.coords + ) + return uw, model, mesh, adv_diff, T + + +def _capture(T): + return np.asarray(T.array[...]).copy() + + +def test_realsolver_restore_recovers_solution_field(): + """snapshot, do a regretted solve, restore — the solution field + itself is recovered exactly.""" + uw, model, mesh, adv_diff, T = _build() + + for _ in range(2): + adv_diff.solve(timestep=1.0e-3) + + pre_T = _capture(T) + snap = model.snapshot() + + adv_diff.solve(timestep=5.0) # absurd Δt: converges, over-diffused + assert not np.allclose(_capture(T), pre_T, atol=1e-8), ( + "the regretted solve was not actually disruptive" + ) + + model.restore(snap) + assert np.array_equal(_capture(T), pre_T), ( + "solution field not exactly recovered after restore" + ) + + +def test_realsolver_regretted_step_leaves_no_trace(): + """THE core guarantee, through a real solver, bit-for-bit. 
+ + B: restore -> K good solves + C: restore -> regretted absurd-Δt solve -> restore + -> same K good solves + + B == C exactly. The discarded step leaves zero trace even though + it ran a real PETSc solve in between. + """ + uw, model, mesh, adv_diff, T = _build() + + for _ in range(3): + adv_diff.solve(timestep=1.0e-3) + snap = model.snapshot() + + # B: restore, then K good solves. + model.restore(snap) + for _ in range(4): + adv_diff.solve(timestep=1.0e-3) + B = _capture(T) + + # C: restore, a regretted solve, restore, the same K good solves. + model.restore(snap) + adv_diff.solve(timestep=5.0) + model.restore(snap) + for _ in range(4): + adv_diff.solve(timestep=1.0e-3) + C = _capture(T) + + assert np.array_equal(B, C), ( + "regretted real solve left a trace after restore — " + f"max abs diff {np.max(np.abs(B - C)):.3e} (expected exactly 0)" + ) + + +def test_realsolver_continuation_within_solver_tolerance(): + """Recovering to a *never-stashed* control is within solver + tolerance (not bit-exact): restore resyncs fields gvec->lvec + rather than reproducing the solver-produced lvec, and the + implicit operator amplifies that to ~tolerance over steps. This + is the documented restore floor, consistent with not capturing + auxiliary solver state by design.""" + uw, model, mesh, adv_diff, T = _build() + + for _ in range(3): + adv_diff.solve(timestep=1.0e-3) + snap = model.snapshot() + + # Control: never snapshotted/restored — straight K solves. + for _ in range(4): + adv_diff.solve(timestep=1.0e-3) + ctrl = _capture(T) + + # Stash path: restore, regretted solve, restore, same K solves. + model.restore(snap) + adv_diff.solve(timestep=5.0) + model.restore(snap) + for _ in range(4): + adv_diff.solve(timestep=1.0e-3) + stash = _capture(T) + + # Not bit-exact vs a never-stashed control (that is the floor), + # but well within solver tolerance and far below the solution + # scale (~0.04 here). 
+ max_diff = float(np.max(np.abs(stash - ctrl))) + assert max_diff < _RESTORE_FLOOR_ATOL, ( + f"continuation drifted beyond the restore floor: " + f"max abs diff {max_diff:.3e} >= {_RESTORE_FLOOR_ATOL:.0e}" + ) + assert not np.array_equal(stash, ctrl), ( + "continuation is unexpectedly bit-exact vs a never-stashed " + "control; if this assertion fires, the restore floor has been " + "eliminated and this test should be tightened to np.array_equal" + )
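
The back-stepping pattern these tests exercise (snapshot, speculative step, acceptance check, restore, retry with a smaller Δt) can be sketched without any Underworld3 dependency. The block below is a minimal illustration only: `ToyModel` and `advance_with_backstep` are hypothetical names invented for this sketch, and the deep-copy snapshot is a stand-in rather than the toolkit's actual `model.snapshot()` / `model.restore()` implementation.

```python
import copy

import numpy as np


class ToyModel:
    """Hypothetical stand-in for a model: particle coords plus a dt history."""

    def __init__(self, coords):
        self.coords = np.asarray(coords, dtype=float)
        self.dt_history = []

    def snapshot(self):
        # Deep copy so later writes to live state cannot leak into the snapshot.
        return copy.deepcopy(self.__dict__)

    def restore(self, snap):
        # Copy again so the same snapshot can be restored more than once.
        self.__dict__.update(copy.deepcopy(snap))

    def step(self, dt):
        # Forward-Euler advection in the outward-radial field V = x - 0.5.
        self.coords = self.coords + dt * (self.coords - 0.5)
        self.dt_history.append(dt)


def advance_with_backstep(model, dt, max_disp, shrink=0.5, max_tries=20):
    """Take one accepted step, shrinking dt until displacement <= max_disp."""
    snap = model.snapshot()
    start = model.coords.copy()
    for _ in range(max_tries):
        model.step(dt)
        disp = np.max(np.linalg.norm(model.coords - start, axis=1))
        if disp <= max_disp:
            return dt
        model.restore(snap)  # discard the regretted step entirely
        dt *= shrink
    raise RuntimeError("no acceptable dt found")


model = ToyModel([[0.0, 0.0], [0.5, 0.5], [1.0, 0.75]])
accepted_dt = advance_with_backstep(model, dt=0.5, max_disp=0.1)
```

With these inputs the first two candidates (0.5 and 0.25) are rejected and rolled back, so `dt_history` ends up holding only the accepted step. That mirrors the property asserted above: a discarded step leaves no trace in the restored state.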