From f07c7ab34cdd26917eef0c4ea9d7c82d7ec2f762 Mon Sep 17 00:00:00 2001 From: Karson Chrispens <33336327+k-chrispens@users.noreply.github.com> Date: Tue, 5 May 2026 21:07:24 +0000 Subject: [PATCH 1/2] feat(guidance): add --alignment-reverse-diffusion arg --- src/sampleworks/utils/guidance_script_arguments.py | 13 +++++++++++++ src/sampleworks/utils/guidance_script_utils.py | 6 ++++-- 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/src/sampleworks/utils/guidance_script_arguments.py b/src/sampleworks/utils/guidance_script_arguments.py index d2cd610..4c2b8a2 100644 --- a/src/sampleworks/utils/guidance_script_arguments.py +++ b/src/sampleworks/utils/guidance_script_arguments.py @@ -228,6 +228,7 @@ class GuidanceConfig: guidance_start: int = -1 augmentation: bool = False align_to_input: bool = False + alignment_reverse_diffusion: bool | None = None recycling_steps: int | None = None num_diffusion_steps: int = 200 @@ -341,6 +342,7 @@ def from_cli( guidance_start=args.guidance_start, augmentation=args.augmentation, align_to_input=args.align_to_input, + alignment_reverse_diffusion=args.alignment_reverse_diffusion, ) # __post_init__ already set defaults for model/guidance-specific @@ -452,6 +454,17 @@ def add_generic_args(parser: argparse.ArgumentParser | GuidanceConfig): action="store_true", help="Enable alignment to input", ) + parser.add_argument( + "--alignment-reverse-diffusion", + action="store_const", + const=True, + default=None, + help=( + "Enable alignment of the noisy state to the denoised prediction " + "during reverse diffusion (described in Boltz-1 paper). Default: enabled for Boltz, " + "disabled for other models." + ), + ) parser.add_argument( "--ensemble-size", type=int, diff --git a/src/sampleworks/utils/guidance_script_utils.py b/src/sampleworks/utils/guidance_script_utils.py index 0b4d60b..cd283d2 100644 --- a/src/sampleworks/utils/guidance_script_utils.py +++ b/src/sampleworks/utils/guidance_script_utils.py @@ -479,8 +479,10 @@ def _run_guidance(args: GuidanceConfig, guidance_type: str, model_wrapper, devic else: raise ValueError(f"Unknown model wrapper class: {wrapper_class_name}") - # Boltz was trained with this, others might not have been. - use_alignment_for_reverse_diffusion = is_boltz + # Boltz was trained with this, other models default to disabled, but the user + # can opt in via --alignment-reverse-diffusion. + override = getattr(args, "alignment_reverse_diffusion", None) + use_alignment_for_reverse_diffusion = is_boltz if override is None else override # Create sampler with model-appropriate settings sampler_config = EDMSamplerConfig( From 578c8d21806abe1a35a977e0620a2ebfe36cbf6b Mon Sep 17 00:00:00 2001 From: Karson Chrispens <33336327+k-chrispens@users.noreply.github.com> Date: Sun, 21 Jun 2026 21:49:13 -0400 Subject: [PATCH 2/2] fix: addressing Marcus's comments --- .../utils/guidance_script_arguments.py | 10 +++++----- .../utils/guidance_script_utils.py | 13 +++++++++---- tests/cli/test_guidance_cli.py | 19 +++++++++++++++++++ tests/utils/test_guidance_script_utils.py | 17 +++++++++++++++++ 4 files changed, 50 insertions(+), 9 deletions(-) diff --git a/src/sampleworks/utils/guidance_script_arguments.py b/src/sampleworks/utils/guidance_script_arguments.py index 4c2b8a2..45f0686 100644 --- a/src/sampleworks/utils/guidance_script_arguments.py +++ b/src/sampleworks/utils/guidance_script_arguments.py @@ -456,13 +456,13 @@ def add_generic_args(parser: argparse.ArgumentParser | GuidanceConfig): ) parser.add_argument( "--alignment-reverse-diffusion", - action="store_const", - const=True, + action=argparse.BooleanOptionalAction, default=None, help=( - "Enable alignment of the noisy state to the denoised prediction " - "during reverse diffusion (described in Boltz-1 paper). Default: enabled for Boltz, " - "disabled for other models." + "Align the noisy state to the denoised prediction during reverse " + "diffusion (described in Boltz-1 paper). Use " + "--no-alignment-reverse-diffusion to disable. Default: enabled for " + "Boltz, disabled for other models." ), ) parser.add_argument( diff --git a/src/sampleworks/utils/guidance_script_utils.py b/src/sampleworks/utils/guidance_script_utils.py index cd283d2..f477ab7 100644 --- a/src/sampleworks/utils/guidance_script_utils.py +++ b/src/sampleworks/utils/guidance_script_utils.py @@ -423,6 +423,12 @@ def run_guidance(args: GuidanceConfig, guidance_type: str, model_wrapper, device return job_result +def _three_state_resolver(value: str | bool | None, default: bool) -> bool: + if value is None: + return default + return bool(value) + + # "guidance_type" is also called "scaler" in many places def _run_guidance(args: GuidanceConfig, guidance_type: str, model_wrapper, device): """Run one configured guidance trajectory and save its outputs.""" @@ -479,10 +485,9 @@ def _run_guidance(args: GuidanceConfig, guidance_type: str, model_wrapper, devic else: raise ValueError(f"Unknown model wrapper class: {wrapper_class_name}") - # Boltz was trained with this, other models default to disabled, but the user - # can opt in via --alignment-reverse-diffusion. - override = getattr(args, "alignment_reverse_diffusion", None) - use_alignment_for_reverse_diffusion = is_boltz if override is None else override + use_alignment_for_reverse_diffusion = _three_state_resolver( + args.alignment_reverse_diffusion, is_boltz + ) # Create sampler with model-appropriate settings sampler_config = EDMSamplerConfig( diff --git a/tests/cli/test_guidance_cli.py b/tests/cli/test_guidance_cli.py index 5e5cf6c..0272009 100644 --- a/tests/cli/test_guidance_cli.py +++ b/tests/cli/test_guidance_cli.py @@ -570,3 +570,22 @@ def test_invalid_preset_model_raises(self): def test_invalid_preset_guidance_type_raises(self): with pytest.raises(ValueError, match="Unknown guidance type"): GuidanceConfig.from_cli(COMMON_ARGS, model="boltz1", guidance_type="typo") + + +class TestAlignmentReverseDiffusion: + """--alignment-reverse-diffusion can be none, yes, or no and propagates into the config.""" + + BASE = ["--model", "boltz2", "--guidance-type", "pure_guidance"] + COMMON_ARGS + + def test_omitted_defaults_to_none(self): + config = GuidanceConfig.from_cli(self.BASE) + assert config.alignment_reverse_diffusion is None + + def test_enable_flag_sets_true(self): + config = GuidanceConfig.from_cli(self.BASE + ["--alignment-reverse-diffusion"]) + assert config.alignment_reverse_diffusion is True + + def test_disable_flag_sets_false(self): + """--no-alignment-reverse-diffusion must be able to force the feature off.""" + config = GuidanceConfig.from_cli(self.BASE + ["--no-alignment-reverse-diffusion"]) + assert config.alignment_reverse_diffusion is False diff --git a/tests/utils/test_guidance_script_utils.py b/tests/utils/test_guidance_script_utils.py index 2b1829f..a91a0fa 100644 --- a/tests/utils/test_guidance_script_utils.py +++ b/tests/utils/test_guidance_script_utils.py @@ -8,6 +8,7 @@ import torch from sampleworks.utils.guidance_script_arguments import GuidanceConfig, JobResult from sampleworks.utils.guidance_script_utils import ( + _three_state_resolver, _write_job_metadata, get_reward_function_and_structure, save_everything, @@ -16,6 +17,22 @@ from tests.utils.atom_array_builders import build_test_atom_array +@pytest.mark.parametrize( + "override, is_boltz, expected", + [ + (None, True, True), # Boltz default: enabled + (None, False, False), # other models default: disabled + (True, False, True), # explicit opt-in on a non-Boltz model + (True, True, True), # explicit on, agrees with Boltz default + (False, True, False), # explicit opt-out overrides the Boltz default + (False, False, False), # explicit off, agrees with non-Boltz default + ], +) +def test_resolve_alignment_reverse_diffusion(override, is_boltz, expected): + """The override wins when set, None means is_boltz default.""" + assert _three_state_resolver(override, is_boltz) is expected + + def test_save_everything_uses_model_atom_array_for_mismatch(tmp_path: Path): """Mismatch final_state should save with model template when provided.""" refined_structure = {"asym_unit": build_test_atom_array(n_atoms=3, with_occupancy=True)}