Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions src/sampleworks/utils/guidance_script_arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,7 @@ class GuidanceConfig:
guidance_start: int = -1
augmentation: bool = False
align_to_input: bool = False
alignment_reverse_diffusion: bool | None = None
recycling_steps: int | None = None
num_diffusion_steps: int = 200

Expand Down Expand Up @@ -341,6 +342,7 @@ def from_cli(
guidance_start=args.guidance_start,
augmentation=args.augmentation,
align_to_input=args.align_to_input,
alignment_reverse_diffusion=args.alignment_reverse_diffusion,
)

# __post_init__ already set defaults for model/guidance-specific
Expand Down Expand Up @@ -452,6 +454,17 @@ def add_generic_args(parser: argparse.ArgumentParser | GuidanceConfig):
action="store_true",
help="Enable alignment to input",
)
parser.add_argument(
"--alignment-reverse-diffusion",
action=argparse.BooleanOptionalAction,
default=None,
help=(
Comment on lines +458 to +461
"Align the noisy state to the denoised prediction during reverse "
"diffusion (described in Boltz-1 paper). Use "
"--no-alignment-reverse-diffusion to disable. Default: enabled for "
"Boltz, disabled for other models."
),
)
Comment thread
Copilot marked this conversation as resolved.
Comment thread
coderabbitai[bot] marked this conversation as resolved.
parser.add_argument(
"--ensemble-size",
type=int,
Expand Down
11 changes: 9 additions & 2 deletions src/sampleworks/utils/guidance_script_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,6 +423,12 @@ def run_guidance(args: GuidanceConfig, guidance_type: str, model_wrapper, device
return job_result


def _three_state_resolver(value: str | bool | None, default: bool) -> bool:
if value is None:
return default
return bool(value)


# "guidance_type" is also called "scaler" in many places
def _run_guidance(args: GuidanceConfig, guidance_type: str, model_wrapper, device):
"""Run one configured guidance trajectory and save its outputs."""
Expand Down Expand Up @@ -479,8 +485,9 @@ def _run_guidance(args: GuidanceConfig, guidance_type: str, model_wrapper, devic
else:
raise ValueError(f"Unknown model wrapper class: {wrapper_class_name}")

# Boltz was trained with this, others might not have been.
use_alignment_for_reverse_diffusion = is_boltz
use_alignment_for_reverse_diffusion = _three_state_resolver(
args.alignment_reverse_diffusion, is_boltz
)

# Create sampler with model-appropriate settings
sampler_config = EDMSamplerConfig(
Expand Down
19 changes: 19 additions & 0 deletions tests/cli/test_guidance_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -570,3 +570,22 @@ def test_invalid_preset_model_raises(self):
def test_invalid_preset_guidance_type_raises(self):
with pytest.raises(ValueError, match="Unknown guidance type"):
GuidanceConfig.from_cli(COMMON_ARGS, model="boltz1", guidance_type="typo")


class TestAlignmentReverseDiffusion:
"""--alignment-reverse-diffusion can be none, yes, or no and propagates into the config."""

BASE = ["--model", "boltz2", "--guidance-type", "pure_guidance"] + COMMON_ARGS

def test_omitted_defaults_to_none(self):
config = GuidanceConfig.from_cli(self.BASE)
assert config.alignment_reverse_diffusion is None

def test_enable_flag_sets_true(self):
config = GuidanceConfig.from_cli(self.BASE + ["--alignment-reverse-diffusion"])
assert config.alignment_reverse_diffusion is True

def test_disable_flag_sets_false(self):
"""--no-alignment-reverse-diffusion must be able to force the feature off."""
config = GuidanceConfig.from_cli(self.BASE + ["--no-alignment-reverse-diffusion"])
assert config.alignment_reverse_diffusion is False
17 changes: 17 additions & 0 deletions tests/utils/test_guidance_script_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import torch
from sampleworks.utils.guidance_script_arguments import GuidanceConfig, JobResult
from sampleworks.utils.guidance_script_utils import (
_three_state_resolver,
_write_job_metadata,
get_reward_function_and_structure,
save_everything,
Expand All @@ -16,6 +17,22 @@
from tests.utils.atom_array_builders import build_test_atom_array


@pytest.mark.parametrize(
"override, is_boltz, expected",
[
(None, True, True), # Boltz default: enabled
(None, False, False), # other models default: disabled
(True, False, True), # explicit opt-in on a non-Boltz model
(True, True, True), # explicit on, agrees with Boltz default
(False, True, False), # explicit opt-out overrides the Boltz default
(False, False, False), # explicit off, agrees with non-Boltz default
],
)
def test_resolve_alignment_reverse_diffusion(override, is_boltz, expected):
"""The override wins when set, None means is_boltz default."""
assert _three_state_resolver(override, is_boltz) is expected


def test_save_everything_uses_model_atom_array_for_mismatch(tmp_path: Path):
"""Mismatch final_state should save with model template when provided."""
refined_structure = {"asym_unit": build_test_atom_array(n_atoms=3, with_occupancy=True)}
Expand Down
Loading