From 63d5ee57a6314ff1a3b18d38009edd6625a71bdb Mon Sep 17 00:00:00 2001 From: HatPdotS Date: Thu, 18 Jun 2026 12:01:36 +0200 Subject: [PATCH] Added anomalous wavlength configuration and switching to the refine cli --- torchref/cli/_common.py | 13 ++++ torchref/cli/refine.py | 8 +++ torchref/refinement/base_refinement.py | 11 ++++ torchref/scaling/collection_scaler.py | 89 ++++++++++++++++++++++---- 4 files changed, 108 insertions(+), 13 deletions(-) diff --git a/torchref/cli/_common.py b/torchref/cli/_common.py index f87fe73..a0710d1 100644 --- a/torchref/cli/_common.py +++ b/torchref/cli/_common.py @@ -157,6 +157,19 @@ def add_adp_mode_arg(parser: argparse.ArgumentParser) -> None: ) +def add_wavelength_arg(parser: argparse.ArgumentParser) -> None: + """Add ``--wavelength`` argument (Angstroms; 0 disables anomalous).""" + parser.add_argument( + "--wavelength", + type=float, + default=1.0, + help="X-ray wavelength in Angstroms, used for anomalous (f'/f'') " + "scattering. Set to 0 to disable anomalous refinement entirely, which " + "also forces a Friedel-merged read of the data (no F(+)/F(-) Bijvoet " + "pairs). Default 1.0.", + ) + + def add_column_args(parser: argparse.ArgumentParser) -> None: """Add ``-csf`` and ``-csig`` column-selection arguments.""" parser.add_argument( diff --git a/torchref/cli/refine.py b/torchref/cli/refine.py index 8223f93..c820ae8 100644 --- a/torchref/cli/refine.py +++ b/torchref/cli/refine.py @@ -38,6 +38,7 @@ add_outdir_arg, add_output_format_args, add_single_model_args, + add_wavelength_arg, add_weights_arg, build_column_names, configure_unbuffered_output, @@ -136,6 +137,7 @@ def main(): "Only used when --with-rigid-body is set.", ) add_adp_mode_arg(refine_group) + add_wavelength_arg(refine_group) res = parser.add_argument_group("Resolution") add_dmin_arg(res) @@ -196,6 +198,10 @@ def main(): f"{args.anisotropic_selection or 'not resname HOH and not element H'})" ) print(adp_line) + if args.wavelength == 0: + print("Anomalous: off (wavelength 0 -> Friedel-merged read)") + else: + print(f"Wavelength: {args.wavelength:.4g} A") if manual_weights: print(f"Manual weights: {json.dumps(manual_weights)}") print("=" * 80) @@ -222,6 +228,7 @@ def main(): sigma_m_scale=args.sigma_m_scale, adp_mode=args.adp_mode, aniso_selection=args.anisotropic_selection, + wavelength=args.wavelength, ) # Apply manual group weights, if given. Merge onto DEFAULT_GROUP_WEIGHTS so @@ -316,6 +323,7 @@ def main(): "anisotropic_selection": ( args.anisotropic_selection if args.adp_mode == "anisotropic" else None ), + "wavelength": args.wavelength, "xray_mode": args.xray_mode, "sigma_m_scale": args.sigma_m_scale, "weights": manual_weights if manual_weights else None, diff --git a/torchref/refinement/base_refinement.py b/torchref/refinement/base_refinement.py index bfd9c30..997023a 100644 --- a/torchref/refinement/base_refinement.py +++ b/torchref/refinement/base_refinement.py @@ -138,6 +138,11 @@ def __init__( ``F(+)/F(-)`` (or ``I(+)/I(-)``) data are auto-detected and loaded as Friedel pairs when present, enabling the model's f'' term. True forces this; False forces a merged load (f'' disabled). + wavelength : float, optional + X-ray wavelength in Angstroms for the anomalous (f'/f'') scattering + correction. Default 1.0. A value of ``0`` means "no anomalous + refinement": it disables the correction (model wavelength ``None``) + and forces a Friedel-merged read (overrides ``anomalous`` to False). adp_mode : str, optional ADP parametrization: ``"isotropic"`` (default) refines a per-atom B-factor; ``"anisotropic"`` refines a 6-component U tensor for the @@ -175,6 +180,12 @@ def __init__( # model right after load, before scaling/restraints/targets. self.adp_mode = adp_mode self.aniso_selection = aniso_selection + # A wavelength of 0 means "no anomalous refinement": disable the f'/f'' + # correction (model wavelength None) and force a Friedel-merged read so + # F(+)/F(-) are not loaded as Bijvoet pairs. + if self.wavelength is not None and float(self.wavelength) == 0.0: + self.wavelength = None + self.anomalous = False # Persistent state and logger (created lazily) self._loss_state: Optional[LossState] = None diff --git a/torchref/scaling/collection_scaler.py b/torchref/scaling/collection_scaler.py index ed41943..71c7f75 100644 --- a/torchref/scaling/collection_scaler.py +++ b/torchref/scaling/collection_scaler.py @@ -20,6 +20,12 @@ import torch.nn as nn from torchref.base.metrics import get_rfactors, nll_xray +from torchref.base.reciprocal import get_scattering_vectors +from torchref.base.targets.xray_ml_sigmaa import ( + SigmaAEstimator, + epsilon_from_hkl, + ml_xray_loss_beta_math, +) from torchref.config import get_default_device, get_float_dtype from torchref.scaling.scaler_base import ScalerBase from torchref.scaling.solvent import SolventModel @@ -304,6 +310,7 @@ def refine_lbfgs_joint( lr: float = 1.0, max_iter: int = 200, history_size: int = 10, + scale_smoothness: float = 1000.0, verbose: bool = True, ) -> dict: """ @@ -336,10 +343,21 @@ def refine_lbfgs_joint( mc = self._model_collection all_keys = [mc.dark_key] + mc.timepoint_names - # Pre-compute all fcalc (detached) + # Pre-compute all fcalc (detached) plus the per-dataset σ_A model-error + # variance (beta/epsilon). beta is estimated ONCE on each dataset's free + # set from the currently-scaled |F_calc| and held constant (detached) + # during the scale optimisation — exactly as the single-dataset + # ``ScalerBase.refine_lbfgs`` does. The σ_A (Read MLF) likelihood is what + # keeps the per-bin log_scale from collapsing toward zero in weak shells; + # the plain σ-weighted Gaussian ``nll_xray`` used previously drove the + # scale down (model under-scaled, R blow-up). fcalc_cache = {} fractions_cache = {} data_cache = {} + beta_cache = {} + eps_cache = {} + work_cache = {} + centric_cache = {} for name in all_keys: if name not in dc: continue @@ -349,16 +367,35 @@ def refine_lbfgs_joint( fobs, sigma = data.get_corrected_data() rfree = data.rfree_flags with torch.no_grad(): - fc = model(hkl) - fcalc_cache[name] = fc.detach() - fractions_cache[name] = model.fractions.detach() + fc = model(hkl).detach() + fracs = model.fractions.detach() + f_sol_raw = self.get_mixed_solvent_raw(fracs) + scaled0 = super(CollectionScaler, self).forward( + fc, f_sol_override=f_sol_raw + ) + fc_amp0 = torch.abs(scaled0).reshape(-1) + fobs0 = fobs.to(fc_amp0.dtype).reshape(-1) + eps0 = epsilon_from_hkl( + hkl, getattr(data, "spacegroup", None) + ).to(fc_amp0.dtype) + s = get_scattering_vectors(hkl, data.cell) + dss0 = (torch.norm(s, dim=1) ** 2).to(fc_amp0.dtype) + beta0, eps0 = SigmaAEstimator().get( + fobs0, fc_amp0, data.centric, eps0, dss0, data.free.mask + ) + fcalc_cache[name] = fc + fractions_cache[name] = fracs data_cache[name] = (fobs, sigma, rfree) - - # Wrap the joint NLL + U-penalty as a LossState target. fcalc is - # detached, so the only leaves in the autograd graph are the - # scaler's own parameters — LossState's probe picks them up at - # registration time, and validate_loss inside state.step handles - # NaN/Inf rejection so no per-target try/except is needed. + beta_cache[name] = beta0 + eps_cache[name] = eps0 + work_cache[name] = data.work + centric_cache[name] = data.centric + + # Wrap the joint σ_A ML loss + U-penalty as a LossState target. fcalc is + # detached, so the only leaves in the autograd graph are the scaler's own + # parameters — LossState's probe picks them up at registration time, and + # validate_loss inside state.step handles NaN/Inf rejection so no + # per-target try/except is needed. scaler_self = self class _CollectionScalerJointTarget(nn.Module): @@ -372,19 +409,45 @@ def forward(self): continue fc = fcalc_cache[nm] fracs = fractions_cache[nm] - fobs_n, sigma_n, _ = data_cache[nm] f_sol_raw = scaler_self.get_mixed_solvent_raw(fracs) scaled = super(CollectionScaler, scaler_self).forward( fc, f_sol_override=f_sol_raw ) - loss = nll_xray(fobs_n, scaled, sigma_n) + # σ_A (Read MLF) scale-fit on the WORK set, with detached + # free-set beta/epsilon — same likelihood the body + # refinement uses. + amp = torch.abs(scaled).reshape(-1) + work = work_cache[nm] + F_obs = work.F.to(amp.dtype) + Fc = work.select(amp) + beta_w = work.select(beta_cache[nm]).to(F_obs.dtype) + eps_w = ( + work.select(eps_cache[nm]).to(F_obs.dtype) + if eps_cache[nm] is not None + else None + ) + centric_w = work.select(centric_cache[nm]) + loss = ml_xray_loss_beta_math( + F_obs, Fc, beta_w, centric_w, epsilon=eps_w + ) if torch.isfinite(loss): total = total + loss n += 1 if n > 0: total = total / n u_penalty = torch.sum(scaler_self.U**2) - return total + u_penalty + # Smoothness (Tikhonov) penalty on the per-bin log_scale: the + # lowest-resolution shells are solvent-dominated with large + # model error, so the σ_A likelihood is nearly flat there and a + # free per-bin scale runs away to -inf (model amplitudes -> 0, + # R blow-up). Penalising squared first differences ties each bin + # to its neighbours so under-constrained shells follow the + # well-determined ones instead of collapsing. + ls = scaler_self.log_scale + smooth_penalty = scale_smoothness * torch.sum( + (ls[1:] - ls[:-1]) ** 2 + ) + return total + u_penalty + smooth_penalty state = LossState(device=self.device) state.register_target("scaler/joint", _CollectionScalerJointTarget())