diff --git a/src/sampleworks/cli/guidance.py b/src/sampleworks/cli/guidance.py index 1066c20f..3e95885e 100644 --- a/src/sampleworks/cli/guidance.py +++ b/src/sampleworks/cli/guidance.py @@ -6,17 +6,22 @@ from loguru import logger from sampleworks.utils.guidance_script_arguments import GuidanceConfig -from sampleworks.utils.guidance_script_utils import get_model_and_device, run_guidance def main(argv: list[str] | None = None) -> int: config = GuidanceConfig.from_cli(argv) + + from loguru import logger + + from sampleworks.utils.guidance_script_utils import get_model_and_device, run_guidance + logger.info(f"Running guidance with config: {config}") device, model_wrapper = get_model_and_device( config.device, getattr(config, "model_checkpoint", None), config.model, method=getattr(config, "method", None), + protpardelle_config_path=getattr(config, "protpardelle_config_path", None), ) result = run_guidance(config, config.guidance_type, model_wrapper, device) return result.exit_code diff --git a/src/sampleworks/core/samplers/edm.py b/src/sampleworks/core/samplers/edm.py index 1f81c551..b0be24b6 100644 --- a/src/sampleworks/core/samplers/edm.py +++ b/src/sampleworks/core/samplers/edm.py @@ -422,8 +422,10 @@ def step( # t_hat will be float if check_context didn't raise # Use no_grad when gradients aren't needed to avoid memory overhead from # gradient checkpointing holding intermediate activations + # TODO testing adding eps to signature for use with Protpardelle-1c, if successful, + # I need to modify the Protocol itself. with torch.set_grad_enabled(allow_gradients): - x_hat_0 = model_wrapper.step(noisy_state, t_hat, features=features) + x_hat_0 = model_wrapper.step(noisy_state, t_hat, eps, features=features) reconciler = ( context.reconciler.to(torch.as_tensor(x_hat_0).device) diff --git a/src/sampleworks/models/protpardelle/wrapper.py b/src/sampleworks/models/protpardelle/wrapper.py index 0417ad6a..c6468438 100644 --- a/src/sampleworks/models/protpardelle/wrapper.py +++ b/src/sampleworks/models/protpardelle/wrapper.py @@ -383,7 +383,9 @@ def featurize(self, structure: dict) -> GenerativeModelInput[ProtpardelleConditi # Concatenate per-chain aatypes in chain order; chains are placed # contiguously at the front of the padded sequence by the helper above. - chain_aatypes = [seq_to_aatype(seq, num_tokens=NUM_AATYPE_TOKENS) for seq in sequences] + chain_aatypes = [ + seq_to_aatype(seq, num_tokens=NUM_AATYPE_TOKENS) for seq in sequences # ty: ignore + ] flat_aatype = torch.cat(chain_aatypes).to(self.device) padded_len = seq_mask.shape[1] aatype = torch.zeros((1, padded_len), dtype=torch.long, device=self.device) @@ -423,7 +425,7 @@ def featurize(self, structure: dict) -> GenerativeModelInput[ProtpardelleConditi atom37_atom_index=atom37_atom_index, sampling_kwargs=sampling_kwargs, sequences=tuple(sequences), - x_self_conditioning=None, + x_self_conditioning=None ) # x_init is a shape-compatible reference drawn from a Gaussian prior. @@ -437,7 +439,7 @@ def featurize(self, structure: dict) -> GenerativeModelInput[ProtpardelleConditi def _atom37_indices_from_atom_array( self, atom_array - ) -> tuple[Tensor, Tensor]: + ) -> tuple[Int[Tensor, "atoms"], Int[Tensor, "atoms"]]: """Derive per-atom atom37 destination indices from an Atomworks atom array. For each atom in ``atom_array`` (the order the sampler's flat ``x_t`` @@ -634,6 +636,7 @@ def step( self, x_t: Float[Tensor, "batch atoms 3"], t: Float[Tensor, "*batch"] | float, + sigma_float: float, *, features: GenerativeModelInput[ProtpardelleConditioning] | None = None, ) -> Float[Tensor, "batch atoms 3"]: @@ -642,16 +645,16 @@ def step( protpardelle-1c/src/protpardelle/core/models.py:L1760 (commit ee378400f25b801fa481028000f9060183d7fb4c on branch main) - The entire reverse-diffusion loop runs internally; the returned tensor - is the final all-atom prediction, flattened to the atoms implied by the - input sequence (the ``atom_mask`` in the conditioning). + The returned tensor is the final all-atom prediction, flattened to the + atoms implied by the input sequence (the ``atom_mask`` in the + conditioning). Parameters ---------- x_t : Float[Tensor, "batch atoms 3"] Noisy structure at timestep :math:`t`. t : Float[Tensor, "*batch"] | float - Current timestep/noise level (:math:`\\hat{t}` from EDM schedule). + Current timestep/noise level (:math:`\hat{t}` from EDM schedule). features : GenerativeModelInput[BoltzConditioning] | None Model features as returned by ``featurize``. diff --git a/src/sampleworks/utils/guidance_script_arguments.py b/src/sampleworks/utils/guidance_script_arguments.py index f18c542f..0364b3e5 100644 --- a/src/sampleworks/utils/guidance_script_arguments.py +++ b/src/sampleworks/utils/guidance_script_arguments.py @@ -197,6 +197,15 @@ def validate_model_checkpoint( "msa_path", "disable_chiral_features", "track_chiral_features", + "protpardelle_config_path", + "protpardelle_s_churn", + "protpardelle_step_scale", + "protpardelle_sidechain_mode", + "protpardelle_skip_mpnn_proportion", + "protpardelle_jump_steps", + "protpardelle_uniform_steps", + "protpardelle_temperature", + "protpardelle_top_p", # generic (overridable) "ensemble_size", "recycling_steps", @@ -612,7 +621,62 @@ def add_protpardelle_specific_args(parser: argparse.ArgumentParser | GuidanceCon "--model-checkpoint", type=str, default=None, - help="Path to Protpardelle checkpoint (default: auto-resolved from /checkpoints/ or pixi env)", + help=( + "Path to Protpardelle checkpoint " + "(default: auto-resolved from /checkpoints/ or pixi env)" + ), + ) + parser.add_argument( + "--protpardelle-config-path", + type=str, + default=None, + help="Path to the Protpardelle model config YAML (default: bundled cc89 config)", + ) + parser.add_argument( + "--protpardelle-s-churn", + type=float, + default=40.0, + help="Protpardelle stochasticity parameter forwarded as s_churn", + ) + parser.add_argument( + "--protpardelle-step-scale", + type=float, + default=1.0, + help="Protpardelle score inverse-temperature scale forwarded as step_scale", + ) + parser.add_argument( + "--protpardelle-sidechain-mode", + action="store_true", + help="Enable Protpardelle all-atom MiniMPNN side-chain co-design", + ) + parser.add_argument( + "--protpardelle-skip-mpnn-proportion", + type=float, + default=1.0, + help="Fraction of denoising steps to skip MiniMPNN at the start", + ) + parser.add_argument( + "--protpardelle-jump-steps", + action="store_true", + help="Use Protpardelle superposition jump-step sampling", + ) + parser.add_argument( + "--protpardelle-uniform-steps", + action=argparse.BooleanOptionalAction, + default=True, + help="Use Protpardelle uniform-step sampling (default: enabled)", + ) + parser.add_argument( + "--protpardelle-temperature", + type=float, + default=1.0, + help="Temperature applied to Protpardelle aatype logits", + ) + parser.add_argument( + "--protpardelle-top-p", + type=float, + default=1.0, + help="Top-p truncation for Protpardelle aatype sampling", ) diff --git a/src/sampleworks/utils/guidance_script_utils.py b/src/sampleworks/utils/guidance_script_utils.py index 2a45448d..1a6905b6 100644 --- a/src/sampleworks/utils/guidance_script_utils.py +++ b/src/sampleworks/utils/guidance_script_utils.py @@ -63,7 +63,7 @@ try: from sampleworks.models.protpardelle.wrapper import ProtpardelleWrapper except ImportError: - ProtpardelleWrapper = None + ProtpardelleWrapper = None # ty:ignore[invalid-assignment] logger.warning("Failed to import Protpardelle, hopefully you're running a different model") from sampleworks.utils.torch_utils import try_gpu @@ -174,6 +174,7 @@ def get_model_and_device( model_type: str, method: str | None = None, model: Any = None, + protpardelle_config_path: str | None = None, ) -> tuple[torch.device, Any]: """Validate a checkpoint, choose a device, and construct the model wrapper.""" validated_checkpoint_path = validate_model_checkpoint(model_type, model_checkpoint_path) @@ -224,8 +225,9 @@ def get_model_and_device( if ProtpardelleWrapper is None: raise ImportError("Protpardelle dependencies not installed") logger.debug(f"Loading Protpardelle model from {validated_checkpoint_path}") + config_path = protpardelle_config_path or "src/sampleworks/data/cc89_epoch415.yaml" model_wrapper = ProtpardelleWrapper( - config_path=str(Path("src/sampleworks/data/cc89_epoch415.yaml").expanduser().resolve()), + config_path=str(Path(config_path).expanduser().resolve()), checkpoint_path=validated_checkpoint_path, device=device, ) @@ -501,11 +503,18 @@ def _run_guidance(args: GuidanceConfig, guidance_type: str, model_wrapper, devic elif "Protpardelle" in wrapper_class_name: from sampleworks.models.protpardelle.wrapper import annotate_structure_for_protpardelle - # TODO: this is where we need to pass in things like step scale, s_churn, etc... - # I'm not entirely sure what all the args are yet though. structure = annotate_structure_for_protpardelle( structure, - ensemble_size=args.ensemble_size + ensemble_size=args.ensemble_size, + num_steps=args.num_diffusion_steps, + s_churn=getattr(args, "protpardelle_s_churn", 40.0), + step_scale=getattr(args, "protpardelle_step_scale", 1.0), + sidechain_mode=getattr(args, "protpardelle_sidechain_mode", False), + skip_mpnn_proportion=getattr(args, "protpardelle_skip_mpnn_proportion", 1.0), + jump_steps=getattr(args, "protpardelle_jump_steps", False), + uniform_steps=getattr(args, "protpardelle_uniform_steps", True), + temperature=getattr(args, "protpardelle_temperature", 1.0), + top_p=getattr(args, "protpardelle_top_p", 1.0), ) edm_sampler_kwargs = { "s_max": 80, "s_min": 0.001, "gamma_0": 0.08, "gamma_min": 0.00, @@ -725,6 +734,7 @@ def run_guidance_job_queue(job_queue_path: str) -> list[JobResult]: template_job.model_checkpoint, template_job.model, # this is not actually the model, it's the model name, e.g. boltz2 method=template_job.method if hasattr(template_job, "method") else None, + protpardelle_config_path=getattr(template_job, "protpardelle_config_path", None), ) job_results = [] for i, job in enumerate(job_queue): diff --git a/tests/cli/test_guidance_cli.py b/tests/cli/test_guidance_cli.py index 0272009e..d905beb3 100644 --- a/tests/cli/test_guidance_cli.py +++ b/tests/cli/test_guidance_cli.py @@ -66,7 +66,46 @@ def test_model_specific_args_rf3_msa(self): "/data/msa.a3m", ] + COMMON_ARGS config = GuidanceConfig.from_cli(argv) - assert config.msa_path == "/data/msa.a3m" # ty: ignore[unresolved-attribute] + assert getattr(config, "msa_path") == "/data/msa.a3m" + + def test_model_specific_args_protpardelle_sampling(self): + argv = [ + "--model", + "protpardelle", + "--guidance-type", + "pure_guidance", + "--model-checkpoint", + "/data/cc89.pth", + "--protpardelle-config-path", + "/data/cc89.yaml", + "--protpardelle-s-churn", + "2.5", + "--protpardelle-step-scale", + "0.75", + "--protpardelle-sidechain-mode", + "--protpardelle-skip-mpnn-proportion", + "0.25", + "--protpardelle-jump-steps", + "--no-protpardelle-uniform-steps", + "--protpardelle-temperature", + "0.9", + "--protpardelle-top-p", + "0.8", + "--num-diffusion-steps", + "64", + ] + COMMON_ARGS + config = GuidanceConfig.from_cli(argv) + assert getattr(config, "model_checkpoint") == "/data/cc89.pth" + assert getattr(config, "protpardelle_config_path") == "/data/cc89.yaml" + assert getattr(config, "protpardelle_s_churn") == 2.5 + assert getattr(config, "protpardelle_step_scale") == 0.75 + assert getattr(config, "protpardelle_sidechain_mode") is True + assert getattr(config, "protpardelle_skip_mpnn_proportion") == 0.25 + assert getattr(config, "protpardelle_jump_steps") is True + assert getattr(config, "protpardelle_uniform_steps") is False + assert getattr(config, "protpardelle_temperature") == 0.9 + assert getattr(config, "protpardelle_top_p") == 0.8 + assert config.num_diffusion_steps == 64 def test_guidance_specific_args_fk(self): argv = [ diff --git a/tests/models/protpardelle/test_protpardelle_wrapper.py b/tests/models/protpardelle/test_protpardelle_wrapper.py index f3c9999c..c31c9b84 100644 --- a/tests/models/protpardelle/test_protpardelle_wrapper.py +++ b/tests/models/protpardelle/test_protpardelle_wrapper.py @@ -143,6 +143,34 @@ def test_defaults_match_ai_allatom_recipe(self): assert config.jump_steps is False assert config.sidechain_mode is False + def test_adds_sampling_options(self): + structure = _protein_structure(SEQ_A) + annotated = annotate_structure_for_protpardelle( + structure, + ensemble_size=2, + num_steps=64, + s_churn=2.5, + step_scale=0.75, + sidechain_mode=True, + skip_mpnn_proportion=0.25, + jump_steps=True, + uniform_steps=False, + temperature=0.9, + top_p=0.8, + ) + + config = annotated["_protpardelle_config"] + assert config.ensemble_size == 2 + assert config.num_steps == 64 + assert config.s_churn == 2.5 + assert config.step_scale == 0.75 + assert config.sidechain_mode is True + assert config.skip_mpnn_proportion == 0.25 + assert config.jump_steps is True + assert config.uniform_steps is False + assert config.temperature == 0.9 + assert config.top_p == 0.8 + class TestBuildSamplingKwargs: def test_defaults(self):