Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions torchref/cli/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,19 @@ def add_adp_mode_arg(parser: argparse.ArgumentParser) -> None:
)


def add_wavelength_arg(parser: argparse.ArgumentParser) -> None:
"""Add ``--wavelength`` argument (Angstroms; 0 disables anomalous)."""
parser.add_argument(
"--wavelength",
type=float,
default=1.0,
help="X-ray wavelength in Angstroms, used for anomalous (f'/f'') "
"scattering. Set to 0 to disable anomalous refinement entirely, which "
"also forces a Friedel-merged read of the data (no F(+)/F(-) Bijvoet "
"pairs). Default 1.0.",
)


def add_column_args(parser: argparse.ArgumentParser) -> None:
"""Add ``-csf`` and ``-csig`` column-selection arguments."""
parser.add_argument(
Expand Down
8 changes: 8 additions & 0 deletions torchref/cli/refine.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
add_outdir_arg,
add_output_format_args,
add_single_model_args,
add_wavelength_arg,
add_weights_arg,
build_column_names,
configure_unbuffered_output,
Expand Down Expand Up @@ -136,6 +137,7 @@ def main():
"Only used when --with-rigid-body is set.",
)
add_adp_mode_arg(refine_group)
add_wavelength_arg(refine_group)

res = parser.add_argument_group("Resolution")
add_dmin_arg(res)
Expand Down Expand Up @@ -196,6 +198,10 @@ def main():
f"{args.anisotropic_selection or 'not resname HOH and not element H'})"
)
print(adp_line)
if args.wavelength == 0:
print("Anomalous: off (wavelength 0 -> Friedel-merged read)")
else:
print(f"Wavelength: {args.wavelength:.4g} A")
if manual_weights:
print(f"Manual weights: {json.dumps(manual_weights)}")
print("=" * 80)
Expand All @@ -222,6 +228,7 @@ def main():
sigma_m_scale=args.sigma_m_scale,
adp_mode=args.adp_mode,
aniso_selection=args.anisotropic_selection,
wavelength=args.wavelength,
)

# Apply manual group weights, if given. Merge onto DEFAULT_GROUP_WEIGHTS so
Expand Down Expand Up @@ -316,6 +323,7 @@ def main():
"anisotropic_selection": (
args.anisotropic_selection if args.adp_mode == "anisotropic" else None
),
"wavelength": args.wavelength,
"xray_mode": args.xray_mode,
"sigma_m_scale": args.sigma_m_scale,
"weights": manual_weights if manual_weights else None,
Expand Down
11 changes: 11 additions & 0 deletions torchref/refinement/base_refinement.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,11 @@ def __init__(
``F(+)/F(-)`` (or ``I(+)/I(-)``) data are auto-detected and loaded as
Friedel pairs when present, enabling the model's f'' term. True forces
this; False forces a merged load (f'' disabled).
wavelength : float, optional
X-ray wavelength in Angstroms for the anomalous (f'/f'') scattering
correction. Default 1.0. A value of ``0`` means "no anomalous
refinement": it disables the correction (model wavelength ``None``)
and forces a Friedel-merged read (overrides ``anomalous`` to False).
adp_mode : str, optional
ADP parametrization: ``"isotropic"`` (default) refines a per-atom
B-factor; ``"anisotropic"`` refines a 6-component U tensor for the
Expand Down Expand Up @@ -175,6 +180,12 @@ def __init__(
# model right after load, before scaling/restraints/targets.
self.adp_mode = adp_mode
self.aniso_selection = aniso_selection
# A wavelength of 0 means "no anomalous refinement": disable the f'/f''
# correction (model wavelength None) and force a Friedel-merged read so
# F(+)/F(-) are not loaded as Bijvoet pairs.
if self.wavelength is not None and float(self.wavelength) == 0.0:
self.wavelength = None
self.anomalous = False

# Persistent state and logger (created lazily)
self._loss_state: Optional[LossState] = None
Expand Down
89 changes: 76 additions & 13 deletions torchref/scaling/collection_scaler.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,12 @@
import torch.nn as nn

from torchref.base.metrics import get_rfactors, nll_xray
from torchref.base.reciprocal import get_scattering_vectors
from torchref.base.targets.xray_ml_sigmaa import (
SigmaAEstimator,
epsilon_from_hkl,
ml_xray_loss_beta_math,
)
from torchref.config import get_default_device, get_float_dtype
from torchref.scaling.scaler_base import ScalerBase
from torchref.scaling.solvent import SolventModel
Expand Down Expand Up @@ -304,6 +310,7 @@ def refine_lbfgs_joint(
lr: float = 1.0,
max_iter: int = 200,
history_size: int = 10,
scale_smoothness: float = 1000.0,
verbose: bool = True,
) -> dict:
"""
Expand Down Expand Up @@ -336,10 +343,21 @@ def refine_lbfgs_joint(
mc = self._model_collection
all_keys = [mc.dark_key] + mc.timepoint_names

# Pre-compute all fcalc (detached)
# Pre-compute all fcalc (detached) plus the per-dataset σ_A model-error
# variance (beta/epsilon). beta is estimated ONCE on each dataset's free
# set from the currently-scaled |F_calc| and held constant (detached)
# during the scale optimisation — exactly as the single-dataset
# ``ScalerBase.refine_lbfgs`` does. The σ_A (Read MLF) likelihood is what
# keeps the per-bin log_scale from collapsing toward zero in weak shells;
# the plain σ-weighted Gaussian ``nll_xray`` used previously drove the
# scale down (model under-scaled, R blow-up).
fcalc_cache = {}
fractions_cache = {}
data_cache = {}
beta_cache = {}
eps_cache = {}
work_cache = {}
centric_cache = {}
for name in all_keys:
if name not in dc:
continue
Expand All @@ -349,16 +367,35 @@ def refine_lbfgs_joint(
fobs, sigma = data.get_corrected_data()
rfree = data.rfree_flags
with torch.no_grad():
fc = model(hkl)
fcalc_cache[name] = fc.detach()
fractions_cache[name] = model.fractions.detach()
fc = model(hkl).detach()
fracs = model.fractions.detach()
f_sol_raw = self.get_mixed_solvent_raw(fracs)
scaled0 = super(CollectionScaler, self).forward(
fc, f_sol_override=f_sol_raw
)
fc_amp0 = torch.abs(scaled0).reshape(-1)
fobs0 = fobs.to(fc_amp0.dtype).reshape(-1)
eps0 = epsilon_from_hkl(
hkl, getattr(data, "spacegroup", None)
).to(fc_amp0.dtype)
s = get_scattering_vectors(hkl, data.cell)
dss0 = (torch.norm(s, dim=1) ** 2).to(fc_amp0.dtype)
beta0, eps0 = SigmaAEstimator().get(
fobs0, fc_amp0, data.centric, eps0, dss0, data.free.mask
)
fcalc_cache[name] = fc
fractions_cache[name] = fracs
data_cache[name] = (fobs, sigma, rfree)

# Wrap the joint NLL + U-penalty as a LossState target. fcalc is
# detached, so the only leaves in the autograd graph are the
# scaler's own parameters — LossState's probe picks them up at
# registration time, and validate_loss inside state.step handles
# NaN/Inf rejection so no per-target try/except is needed.
beta_cache[name] = beta0
eps_cache[name] = eps0
work_cache[name] = data.work
centric_cache[name] = data.centric

# Wrap the joint σ_A ML loss + U-penalty as a LossState target. fcalc is
# detached, so the only leaves in the autograd graph are the scaler's own
# parameters — LossState's probe picks them up at registration time, and
# validate_loss inside state.step handles NaN/Inf rejection so no
# per-target try/except is needed.
scaler_self = self

class _CollectionScalerJointTarget(nn.Module):
Expand All @@ -372,19 +409,45 @@ def forward(self):
continue
fc = fcalc_cache[nm]
fracs = fractions_cache[nm]
fobs_n, sigma_n, _ = data_cache[nm]
f_sol_raw = scaler_self.get_mixed_solvent_raw(fracs)
scaled = super(CollectionScaler, scaler_self).forward(
fc, f_sol_override=f_sol_raw
)
loss = nll_xray(fobs_n, scaled, sigma_n)
# σ_A (Read MLF) scale-fit on the WORK set, with detached
# free-set beta/epsilon — same likelihood the body
# refinement uses.
amp = torch.abs(scaled).reshape(-1)
work = work_cache[nm]
F_obs = work.F.to(amp.dtype)
Fc = work.select(amp)
beta_w = work.select(beta_cache[nm]).to(F_obs.dtype)
eps_w = (
work.select(eps_cache[nm]).to(F_obs.dtype)
if eps_cache[nm] is not None
else None
)
centric_w = work.select(centric_cache[nm])
loss = ml_xray_loss_beta_math(
F_obs, Fc, beta_w, centric_w, epsilon=eps_w
)
if torch.isfinite(loss):
total = total + loss
n += 1
if n > 0:
total = total / n
u_penalty = torch.sum(scaler_self.U**2)
return total + u_penalty
# Smoothness (Tikhonov) penalty on the per-bin log_scale: the
# lowest-resolution shells are solvent-dominated with large
# model error, so the σ_A likelihood is nearly flat there and a
# free per-bin scale runs away to -inf (model amplitudes -> 0,
# R blow-up). Penalising squared first differences ties each bin
# to its neighbours so under-constrained shells follow the
# well-determined ones instead of collapsing.
ls = scaler_self.log_scale
smooth_penalty = scale_smoothness * torch.sum(
(ls[1:] - ls[:-1]) ** 2
)
return total + u_penalty + smooth_penalty

state = LossState(device=self.device)
state.register_target("scaler/joint", _CollectionScalerJointTarget())
Expand Down
Loading