Add focus ordering and ispta balancing #449#452
Add focus ordering and ispta balancing #449#452samueljwu wants to merge 1 commit intoOpenwaterHealth:mainfrom
Conversation
9e65a5d to
00e575b
Compare
|
TY for your updated work! Have you fully addressed Peter's review from #450 and are you ready for another review? |
|
Yes, they are all addressed in this commit. Please let me know if there's anything that needs adjustments! Thanks! |
There was a problem hiding this comment.
Pull request overview
Adds optional focus ordering to pulse sequences so per-focus pulse hit counts can be non-uniform, and introduces ISPTA-based balancing during solution scaling to redistribute pulses toward weaker foci while keeping total pulse count constant.
Changes:
- Add
Sequence.focus_orderwith validation and serialization support. - Add
Solutionhelpers for deriving focus order/counts and for computing/building balanced focus allocations. - Update
Protocol.calc_solution()to (a) skip the pulse-count divisibility fix whenfocus_orderis present, (b) weight aggregated intensity by focus hit counts, and (c) passProtocol.scaling_optionsthrough toSolution.scale().
Reviewed changes
Copilot reviewed 6 out of 6 changed files in this pull request and generated 5 comments.
Show a summary per file
| File | Description |
|---|---|
src/openlifu/bf/sequence.py |
Adds focus_order field + validation and table output. |
src/openlifu/plan/solution.py |
Adds focus-order/count helpers and ISPTA balancing logic; extends analyze()/get_ita() signatures. |
src/openlifu/plan/protocol.py |
Adds scaling_options, skips pulse mismatch handling when focus_order is set, and weights intensity aggregation by focus counts. |
tests/test_sequence.py |
Extends dict round-trip test and adds focus_order validation tests. |
tests/test_solution.py / tests/test_protocol.py |
Adds tests for focus-counting, balancing helpers, and skipping pulse mismatch when focus_order exists. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| """Get the focus index order for each pulse.""" | ||
| if self.sequence.focus_order is not None: | ||
| return np.array(self.sequence.focus_order) | ||
| return (np.arange(self.sequence.pulse_count) - 1) % self.num_foci() + 1 |
There was a problem hiding this comment.
get_focus_order() builds the default round-robin sequence as (np.arange(pulse_count) - 1) % num_foci + 1, which starts the sequence at num_foci (e.g., for 3 foci it yields [3,1,2,...]) rather than [1,2,3,...]. This will skew focus counts/weighting whenever focus_order is not explicitly provided. Consider using np.arange(pulse_count) % num_foci + 1 (and handle num_foci()==0 with a clear error or empty result).
| return (np.arange(self.sequence.pulse_count) - 1) % self.num_foci() + 1 | |
| num_foci = self.num_foci() | |
| if num_foci == 0: | |
| raise ValueError("Cannot compute default focus order when there are no foci") | |
| return np.arange(self.sequence.pulse_count) % num_foci + 1 |
| if focus_counts is None: | ||
| focus_counts = self.get_focus_counts() | ||
| focus_counts = np.asarray(focus_counts) | ||
| if focus_counts.shape != (self.num_foci(),): | ||
| raise ValueError(f"Focus counts must have one value per focus ({self.num_foci()})") | ||
| if np.any(focus_counts < 0): | ||
| raise ValueError("Focus counts must be non-negative") | ||
| counts = focus_counts.reshape((1, 1, 1, self.num_foci())) | ||
| intensity = intensity_scaled.copy(deep=True) | ||
| isppa_avg = np.sum(np.expand_dims(intensity.data, axis=-1) * counts, axis=-1) / np.sum(counts) | ||
| intensity.data = isppa_avg * pulsetrain_dutycycle * treatment_dutycycle |
There was a problem hiding this comment.
get_ita() currently computes isppa_avg via np.sum(np.expand_dims(intensity.data, axis=-1) * counts, axis=-1) / np.sum(counts). With the current array shapes/dims (intensity has a focal_point_index dimension), this does not actually apply the per-focus weights; it effectively cancels out and leaves the focal_point_index dimension intact. That means focus_counts has no effect on ITA, undermining ISPTA balancing. Consider computing a weighted mean over the focal_point_index dimension (e.g., (intensity_scaled * focus_weights).sum(dim='focal_point_index')) and returning an ITA DataArray without focal_point_index so downstream analysis uses the treatment-averaged intensity.
| if (self.sequence.focus_order is not None and len(self.foci) > 0 and max(self.sequence.focus_order) > len(self.foci)): | ||
| raise ValueError(f"Focus order index {max(self.sequence.focus_order)} exceeds number of foci ({len(self.foci)})") |
There was a problem hiding this comment.
Solution.__post_init__ validates focus_order using max(self.sequence.focus_order), which will raise a built-in ValueError if focus_order is later mutated to [] (dataclass fields are mutable) and it only checks the upper bound. If the intent is to validate focus_order at solution creation time, consider guarding against empty lists and validating lower bound / length as well (or centralizing validation in a helper so Sequence construction and later mutations are checked consistently).
| if self.sequence.focus_order is not None and max(self.sequence.focus_order) > len(foci): | ||
| raise ValueError(f"Focus order index {max(self.sequence.focus_order)} exceeds number of foci ({len(foci)})") | ||
|
|
||
| # updating solution sequence if pulse mismatch | ||
| if (self.sequence.pulse_count % len(foci)) != 0: | ||
| if self.sequence.focus_order is None and (self.sequence.pulse_count % len(foci)) != 0: | ||
| self.fix_pulse_mismatch(on_pulse_mismatch, foci) |
There was a problem hiding this comment.
calc_solution() only validates focus_order via max(self.sequence.focus_order) > len(foci). Since Sequence.focus_order can be mutated after initialization (bypassing Sequence.__post_init__), this misses other invalid states (empty list, wrong length vs pulse_count, non-positive indices) and max([]) would crash with an unhelpful exception. Consider performing a full validation here (length, type/positivity, and bounds vs len(foci)) or calling a shared Sequence.validate_focus_order(num_foci=...) helper.
| focus_counts = solution.get_focus_counts() | ||
| focus_weights = xa.DataArray( | ||
| focus_counts / np.sum(focus_counts), | ||
| dims=("focal_point_index",), | ||
| coords={"focal_point_index": solution.simulation_result.coords["focal_point_index"]}, | ||
| ) | ||
| intensity = solution.simulation_result['intensity'] | ||
| intensity_aggregated = (intensity * focus_weights).sum(dim="focal_point_index", keep_attrs=True) | ||
| intensity_aggregated.attrs.update(intensity.attrs) |
There was a problem hiding this comment.
The new weighted intensity aggregation in calc_solution() is core to focus_order support, but there doesn't appear to be a unit test asserting that aggregation is weighted by focus_counts (and differs from the previous unweighted mean). Adding a focused test would help prevent regressions in the weighting logic.
Summary
Adds optional
Sequence.focus_ordersupport for flexible focal pulse assignment and integrates ISPTA balancing intoSolution.scale().Closes #449
Supersedes #450, which I accidentally closed while syncing my fork. Reopening here on a clean branch.
Changes
Focus ordering
Sequence.focus_orderattribute to define the focus index order for each pulse.Solution.get_focus_order()Solution.get_focus_counts()Protocol.calc_solution()so the old pulse-count divisibility check only applies whenfocus_orderis not provided.Protocol.calc_solution()to use a weighted mean based on focus hit counts instead of an unweighted mean.ISPTA balancing
Protocol.scaling_optionsattribute for scaling configurationscaling_optionsis passed through toSolution.scale()viasolution.scale(..., **self.scaling_options)Solution.scale()with optional balancing arguments:balance_methodbalance_metricorderingcompute_balanced_focus_counts()build_focus_order()focus_countsoverrides toanalyze()andget_ita()so balancing can compute baseline ISPTA using equal per-focus counts without depending on an existingSequence.focus_orderTesting
test_dict_undict_sequence()test_sequence_focus_order_validation()test_solution_get_focus_counts_with_explicit_focus_order()test_solution_build_focus_order_minimizes_repeats()test_solution_compute_balanced_focus_counts_preserves_total_and_weights_inverse_metric()test_to_dict_from_dict()test_calc_solution_skips_pulse_mismatch_when_focus_order_present()All existing tests pass
Pre-commit hooks pass