diff --git a/src/sampleworks/core/forward_models/xray/real_space_density_deps/qfit/sf.py b/src/sampleworks/core/forward_models/xray/real_space_density_deps/qfit/sf.py index 6a1a1228..a9e4a607 100644 --- a/src/sampleworks/core/forward_models/xray/real_space_density_deps/qfit/sf.py +++ b/src/sampleworks/core/forward_models/xray/real_space_density_deps/qfit/sf.py @@ -1,4 +1,3 @@ -# ty: ignore # ruff: ignore """Cromer-Mann coefficents and related constants diff --git a/src/sampleworks/core/forward_models/xray/real_space_density_deps/qfit/transformer.py b/src/sampleworks/core/forward_models/xray/real_space_density_deps/qfit/transformer.py index 90f92cae..d8367179 100644 --- a/src/sampleworks/core/forward_models/xray/real_space_density_deps/qfit/transformer.py +++ b/src/sampleworks/core/forward_models/xray/real_space_density_deps/qfit/transformer.py @@ -1,5 +1,3 @@ -# ty: ignore - """Minimal components of qFit's transformer module necessary to implement qFit's volume.py diff --git a/src/sampleworks/core/forward_models/xray/real_space_density_deps/qfit/unitcell.py b/src/sampleworks/core/forward_models/xray/real_space_density_deps/qfit/unitcell.py index 582d2beb..23988e6b 100644 --- a/src/sampleworks/core/forward_models/xray/real_space_density_deps/qfit/unitcell.py +++ b/src/sampleworks/core/forward_models/xray/real_space_density_deps/qfit/unitcell.py @@ -1,6 +1,3 @@ -# type: ignore -# ty: ignore - """Classes for handling unit cell transformation.""" from itertools import product diff --git a/src/sampleworks/core/samplers/edm.py b/src/sampleworks/core/samplers/edm.py index 81456170..7b94fdaa 100644 --- a/src/sampleworks/core/samplers/edm.py +++ b/src/sampleworks/core/samplers/edm.py @@ -4,7 +4,7 @@ from collections.abc import Mapping from dataclasses import dataclass -from typing import TYPE_CHECKING +from typing import cast, TYPE_CHECKING import einx import torch @@ -395,7 +395,8 @@ def step( t_hat = context.t_effective dt = context.dt - eps_scale = context.noise_scale + # check_context() guarantees noise_scale is present for sampling steps. + eps_scale = cast(float, context.noise_scale) allow_gradients = True if scaler and getattr(scaler, "requires_gradients", False) else False centroid = einx.mean("... [n] c", state) @@ -415,7 +416,7 @@ def step( # Store eps separately for proper frame transformation # eps_scale will be float if check_context didn't raise - eps = torch.randn_like(maybe_augmented_state) * eps_scale # ty: ignore[unsupported-operator] + eps = torch.randn_like(maybe_augmented_state) * eps_scale noisy_state = maybe_augmented_state + eps noisy_state = torch.as_tensor(noisy_state).detach().requires_grad_(allow_gradients) @@ -507,7 +508,7 @@ def step( # multiple particles. # Only compute when noise_var > 0 to avoid division by near-zero # (matching Boltz behavior) - noise_var = eps_scale**2 # ty: ignore[unsupported-operator] + noise_var = eps_scale**2 if noise_var > 0: log_proposal_correction = einx.sum( "... [b n c]", eps_working_frame**2 - (eps_working_frame + proposal_shift) ** 2 diff --git a/src/sampleworks/metrics/lddt.py b/src/sampleworks/metrics/lddt.py index 8602e755..ca49617f 100644 --- a/src/sampleworks/metrics/lddt.py +++ b/src/sampleworks/metrics/lddt.py @@ -336,11 +336,13 @@ def compute( # Set token ids to something useful for residue-level LDDT (chain ID + residue number) # Note: add_global_token_id_annotation works with both AtomArray and # AtomArrayStack at runtime - predicted_atom_array_stack = add_global_token_id_annotation( - predicted_atom_array_stack # ty: ignore[invalid-argument-type] + predicted_atom_array_stack = cast( + AtomArrayStack, + add_global_token_id_annotation(cast(Any, predicted_atom_array_stack)), ) - ground_truth_atom_array_stack = add_global_token_id_annotation( - ground_truth_atom_array_stack # ty: ignore[invalid-argument-type] + ground_truth_atom_array_stack = cast( + AtomArrayStack, + add_global_token_id_annotation(cast(Any, ground_truth_atom_array_stack)), ) # restrict to atoms that are present in both structures @@ -477,11 +479,13 @@ def compute( # set the token ids, to avoid any possible confusion later on # Note: add_global_token_id_annotation works with AtomArrayStack at runtime - filtered_predicted = add_global_token_id_annotation( - filtered_predicted # ty: ignore[invalid-argument-type] + filtered_predicted = cast( + AtomArrayStack, + add_global_token_id_annotation(cast(Any, filtered_predicted)), ) - filtered_ground_truth = add_global_token_id_annotation( - filtered_ground_truth # ty: ignore[invalid-argument-type] + filtered_ground_truth = cast( + AtomArrayStack, + add_global_token_id_annotation(cast(Any, filtered_ground_truth)), ) lddt_features = extract_lddt_features_from_atom_arrays( diff --git a/src/sampleworks/metrics/rmsd.py b/src/sampleworks/metrics/rmsd.py index e3362650..7bd90020 100644 --- a/src/sampleworks/metrics/rmsd.py +++ b/src/sampleworks/metrics/rmsd.py @@ -92,11 +92,13 @@ def compute( - ``best_of_{N}_segment_rmsd``: minimum of ``segment_rmsd``. """ # 1. Annotate token IDs so atoms can be grouped into residues downstream. - predicted_atom_array_stack = add_global_token_id_annotation( - predicted_atom_array_stack # ty: ignore[invalid-argument-type] (accepts AtomArray|AtomArrayStack at runtime; stub is narrower) + predicted_atom_array_stack = cast( + AtomArrayStack, + add_global_token_id_annotation(cast(Any, predicted_atom_array_stack)), ) - ground_truth_atom_array_stack = add_global_token_id_annotation( - ground_truth_atom_array_stack # ty: ignore[invalid-argument-type] (accepts AtomArray|AtomArrayStack at runtime; stub is narrower) + ground_truth_atom_array_stack = cast( + AtomArrayStack, + add_global_token_id_annotation(cast(Any, ground_truth_atom_array_stack)), ) # 2. Restrict both stacks to atoms present in both structures, in matching order. diff --git a/src/sampleworks/models/boltz/wrapper.py b/src/sampleworks/models/boltz/wrapper.py index d75b6ab5..5254341a 100644 --- a/src/sampleworks/models/boltz/wrapper.py +++ b/src/sampleworks/models/boltz/wrapper.py @@ -454,7 +454,7 @@ def create_boltz_input_from_structure( sequence_to_chains.setdefault(seq, []).append(chain_id) unique_chain_sequences = {chains[0]: seq for seq, chains in sequence_to_chains.items()} - msa_paths_unique = msa_manager.get_msa(unique_chain_sequences, msa_pairing_strategy) # ty: ignore[invalid-argument-type] + msa_paths_unique = msa_manager.get_msa(unique_chain_sequences, msa_pairing_strategy) msa_paths = {} for seq, chains_with_seq in sequence_to_chains.items(): @@ -739,7 +739,9 @@ def featurize(self, structure: dict) -> GenerativeModelInput[BoltzConditioning]: "x_init from prior. This means align_to_input will not work properly," " and reward functions dependent on this won't be accurate." ) - temp_features = GenerativeModelInput(x_init=None, conditioning=conditioning) # ty: ignore[invalid-argument-type] + temp_features = GenerativeModelInput[BoltzConditioning]( + x_init=cast(Any, None), conditioning=conditioning + ) x_init = self.initialize_from_prior(batch_size=ensemble_size, features=temp_features) return GenerativeModelInput(x_init=x_init, conditioning=conditioning) @@ -810,7 +812,7 @@ def _pairformer_pass( z = z + msa_module(z, s_inputs, features, use_kernels=self.model.use_kernels) if self.model.is_pairformer_compiled: - pairformer_module = self.model.pairformer_module._orig_mod + pairformer_module = cast(Any, self.model.pairformer_module)._orig_mod else: pairformer_module = self.model.pairformer_module @@ -1078,8 +1080,9 @@ def _setup_data_module( ) processed_dir = out_dir / "processed" + load_manifest = cast(Any, Manifest.load) processed = BoltzProcessedInput( - manifest=Manifest.load(processed_dir / "manifest.json"), + manifest=load_manifest(processed_dir / "manifest.json"), targets_dir=processed_dir / "structures", msa_dir=processed_dir / "msa", constraints_dir=(processed_dir / "constraints") @@ -1194,7 +1197,9 @@ def featurize(self, structure: dict) -> GenerativeModelInput[BoltzConditioning]: "x_init from prior. This means align_to_input will not work properly," " and reward functions dependent on this won't be accurate." ) - temp_features = GenerativeModelInput(x_init=None, conditioning=conditioning) # ty: ignore[invalid-argument-type] + temp_features = GenerativeModelInput[BoltzConditioning]( + x_init=cast(Any, None), conditioning=conditioning + ) x_init = self.initialize_from_prior(batch_size=ensemble_size, features=temp_features) return GenerativeModelInput(x_init=x_init, conditioning=conditioning) @@ -1368,7 +1373,7 @@ def _pairformer_pass( ) if self.model.is_pairformer_compiled: - pairformer_module = self.model.pairformer_module._orig_mod + pairformer_module = cast(Any, self.model.pairformer_module)._orig_mod else: pairformer_module = self.model.pairformer_module diff --git a/src/sampleworks/models/rf3/wrapper.py b/src/sampleworks/models/rf3/wrapper.py index 023371f6..f4fc6771 100644 --- a/src/sampleworks/models/rf3/wrapper.py +++ b/src/sampleworks/models/rf3/wrapper.py @@ -365,8 +365,10 @@ def featurize(self, structure: dict) -> GenerativeModelInput[RF3Conditioning]: InferenceInput, input_batch[0] ) # since we're not batching, the loader returns a list of length 1 - # (Hydra instantiation of pipeline means it is going to be hard to type check here) - pipeline_output = self.inference_engine.pipeline(input_spec.to_pipeline_input()) # ty: ignore[call-non-callable] + # Hydra instantiation leaves pipeline imprecisely typed even though it is callable + # at runtime. + pipeline = cast(Any, self.inference_engine.pipeline) + pipeline_output = pipeline(input_spec.to_pipeline_input()) pipeline_output = trainer.fabric.to_device(pipeline_output) features = trainer._assemble_network_inputs(pipeline_output) diff --git a/src/sampleworks/utils/framework_utils.py b/src/sampleworks/utils/framework_utils.py index fa5d61b3..143be4c0 100644 --- a/src/sampleworks/utils/framework_utils.py +++ b/src/sampleworks/utils/framework_utils.py @@ -1,6 +1,6 @@ from collections.abc import Callable from functools import wraps -from typing import Any, overload, ParamSpec, TYPE_CHECKING, TypeVar +from typing import Any, cast, overload, ParamSpec, TYPE_CHECKING, TypeVar import numpy as np @@ -184,8 +184,8 @@ def match_batch(array: Array | np.ndarray, target_batch_size: int) -> Array | np # singleton: lazy broadcast (no copy) if b == 1: - return _broadcast_to(array, (n, *array.shape[1:])) # ty: ignore[invalid-argument-type] + return _broadcast_to(cast(Any, array), (n, *array.shape[1:])) # divisible: tile if n % b: raise ValueError(f"batch {b} not divisible into target {n}") - return _tile(array, (n // b, *(1,) * (array.ndim - 1))) # ty: ignore[invalid-argument-type] + return _tile(cast(Any, array), (n // b, *(1,) * (array.ndim - 1))) diff --git a/src/sampleworks/utils/msa.py b/src/sampleworks/utils/msa.py index 83e28052..5ede9910 100644 --- a/src/sampleworks/utils/msa.py +++ b/src/sampleworks/utils/msa.py @@ -1,3 +1,4 @@ +from collections.abc import Mapping from hashlib import sha3_256 from pathlib import Path @@ -18,6 +19,14 @@ MAX_PAIRED_SEQS = 8192 MAX_MSA_SEQS = 16384 +MSAData = Mapping[str, str] | Mapping[int, str] | Mapping[str | int, str] + + +def _msa_data_key_sort_key(key: str | int) -> tuple[int, int | str]: + """Return a deterministic sort key for supported MSA sequence keys.""" + if isinstance(key, int): + return (0, key) + return (1, key) def _validate_msa_cache_contents(msa_hash: str, msa_dir: Path) -> None: @@ -104,7 +113,7 @@ def _validate_msa_cache_contents(msa_hash: str, msa_dir: Path) -> None: # For the love of decent code, don't copy this and use it somewhere else and respect the # leading underscore! def _compute_msa( - data: dict[str | int, str], + data: MSAData, target_id: str, msa_dir: Path, msa_server_url: str, @@ -118,7 +127,7 @@ def _compute_msa( Parameters ---------- - data : dict[str | int, str] + data : MSAData The input protein sequences. target_id : str The target id. @@ -192,7 +201,7 @@ def _compute_msa( # order as they are in `data`, and furthermore just returns a list of strings, the content # of each string being a single sequence alignment. It's some weird file parsing that we # should clean up so users don't break it or have to worry about it. - outputs = {} + outputs: dict[str | int, Path] = {} for idx, name in enumerate(data): # Get paired sequences paired = paired_msas[idx].strip().splitlines() @@ -308,14 +317,14 @@ def __init__( self._cache_hits = 0 @staticmethod - def _hash_arguments(data: dict[str | int, str], msa_pairing_strategy: str) -> str: + def _hash_arguments(data: MSAData, msa_pairing_strategy: str) -> str: encoded_sequence_tuple = str.encode(str(tuple(data.values())) + msa_pairing_strategy) hexdigest = sha3_256(encoded_sequence_tuple).hexdigest() return hexdigest def get_msa( self, - data: dict[str | int, str], + data: MSAData, msa_pairing_strategy: str, structure_predictor: str | StructurePredictor = StructurePredictor.BOLTZ_2, ) -> dict[str | int, Path]: @@ -323,7 +332,7 @@ def get_msa( Parameters ---------- - data : dict[str | int, str] + data : MSAData A dictionary mapping target (usu. chain or index) names to protein sequences. msa_pairing_strategy : str The MSA pairing strategy to use (usually "greedy"). @@ -382,8 +391,9 @@ def get_msa( protenix_dir.mkdir(parents=True, exist_ok=True) # Protenix adds extra information, easiest just to use their pipeline. # make sure sort order stays the same: - data_keys = sorted(data.keys()) - sequences = [data[key] for key in data_keys] + data_items = sorted(data.items(), key=lambda item: _msa_data_key_sort_key(item[0])) + data_keys = [key for key, _ in data_items] + sequences = [sequence for _, sequence in data_items] out_dir = self.msa_dir / "protenix" / hash_key msa_directories = [out_dir / str(idx) for idx in data_keys] @@ -392,7 +402,7 @@ def get_msa( (out_dir / str(idx) / fn).exists() for idx in data_keys for fn in reqd_files ) if need_msas: - msa_directories = protenix_msa_search(sequences, out_dir, mode="protenix") + msa_directories = protenix_msa_search(sequences, str(out_dir), mode="protenix") self._api_calls += 1 else: self._cache_hits += 1 diff --git a/tests/conftest.py b/tests/conftest.py index 4fba62e4..e84d889e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1089,5 +1089,7 @@ def perturbed_coords( """ torch.manual_seed(42) base = converging_mock_wrapper.target - perturbation = torch.randn_like(base) * 0.1 # ty: ignore[invalid-argument-type] - return base, base + perturbation # ty: ignore[invalid-return-type, unsupported-operator] + if base is None: + raise ValueError("converging_mock_wrapper must provide target coordinates") + perturbation = torch.randn_like(base) * 0.1 + return base, base + perturbation diff --git a/tests/utils/test_atom_array_utils.py b/tests/utils/test_atom_array_utils.py index 76f0f2b1..47271c2a 100644 --- a/tests/utils/test_atom_array_utils.py +++ b/tests/utils/test_atom_array_utils.py @@ -1,6 +1,6 @@ """Tests for atom_array_utils module.""" -from typing import cast +from typing import Any, cast import numpy as np import pytest @@ -200,12 +200,12 @@ def test_invalid_type_raises_error(self): invalid_input = "not an atom array" with pytest.raises(TypeError, match="can only accept AtomArray or AtomArrayStack"): - select_altloc(invalid_input, "A") # ty: ignore[invalid-argument-type] + select_altloc(cast(Any, invalid_input), "A") def test_none_input_raises_error(self): """Test that None input raises TypeError.""" with pytest.raises(TypeError, match="can only accept AtomArray or AtomArrayStack"): - select_altloc(None, "A") # ty: ignore[invalid-argument-type] + select_altloc(cast(Any, None), "A") class TestSelectAltlocEdgeCases: @@ -358,7 +358,7 @@ def test_invalid_type_first_arg(self): array.coord = np.random.rand(3, 3) with pytest.raises(TypeError, match="must be AtomArray or AtomArrayStack"): - filter_to_common_atoms("not an array", array) # ty: ignore[invalid-argument-type] + filter_to_common_atoms(cast(Any, "not an array"), array) def test_invalid_type_second_arg(self): """Test that invalid second argument raises TypeError.""" @@ -366,7 +366,7 @@ def test_invalid_type_second_arg(self): array.coord = np.random.rand(3, 3) with pytest.raises(TypeError, match="must be AtomArray or AtomArrayStack"): - filter_to_common_atoms(array, None) # ty: ignore[invalid-argument-type] + filter_to_common_atoms(array, cast(Any, None)) def test_preserves_coordinates(self, atom_array_partial_overlap): """Test that coordinates are preserved for common atoms.""" @@ -443,7 +443,7 @@ def test_is_idempotent(self, structure_fixture, request): def test_invalid_type_raises_error(self): with pytest.raises(TypeError, match="can only accept AtomArray or AtomArrayStack"): - remove_hydrogens("bad input") # ty:ignore[invalid-argument-type] + remove_hydrogens(cast(Any, "bad input")) class TestKeepPolymer: @@ -467,7 +467,7 @@ def test_atom_array_and_stack_give_same_atom_count(self, test_structure): def test_invalid_type_raises_error(self): with pytest.raises(TypeError, match="can only accept AtomArray or AtomArrayStack"): - keep_polymer(42) # ty:ignore[invalid-argument-type] + keep_polymer(cast(Any, 42)) class TestKeepAminoAcids: @@ -494,7 +494,7 @@ def test_atom_array_and_stack_give_same_atom_count(self, test_structure): def test_invalid_type_raises_error(self): with pytest.raises(TypeError, match="can only accept AtomArray or AtomArrayStack"): - keep_amino_acids(None) # ty:ignore[invalid-argument-type] + keep_amino_acids(cast(Any, None)) class TestFilterFunctionsIntegration: diff --git a/tests/utils/test_framework_utils.py b/tests/utils/test_framework_utils.py index 1738b6ae..5905cf3d 100644 --- a/tests/utils/test_framework_utils.py +++ b/tests/utils/test_framework_utils.py @@ -1,5 +1,7 @@ """Tests for framework conversion utilities from sampleworks.utils.framework_utils.""" +from typing import Any, cast + import numpy as np import pytest from sampleworks.utils.framework_utils import ( @@ -527,7 +529,7 @@ def test_raises_on_incompatible_batch_sizes(self): def test_raises_on_scalar_input(self): scalar = np.float64(1.0) with pytest.raises(ValueError, match="ndim >= 1"): - match_batch(scalar, target_batch_size=2) # ty: ignore[no-matching-overload] + match_batch(cast(Any, scalar), target_batch_size=2) def test_1d_array(self): array = np.array([42.0]) @@ -549,4 +551,4 @@ def test_preserves_dtype(self): def test_raises_on_unsupported_type(self): with pytest.raises(TypeError, match="unsupported array type"): - match_batch([1, 2, 3], target_batch_size=2) # ty: ignore[no-matching-overload] + match_batch(cast(Any, [1, 2, 3]), target_batch_size=2) diff --git a/tests/utils/test_msa.py b/tests/utils/test_msa.py index 412fb5d0..25f41519 100644 --- a/tests/utils/test_msa.py +++ b/tests/utils/test_msa.py @@ -8,7 +8,7 @@ import pytest from sampleworks.utils import msa as msa_module from sampleworks.utils.guidance_constants import StructurePredictor -from sampleworks.utils.msa import _compute_msa, MSAManager +from sampleworks.utils.msa import _compute_msa, _msa_data_key_sort_key, MSAManager # ============================================================================ @@ -30,6 +30,13 @@ def test_hash_arguments_is_deterministic_and_input_sensitive(): assert base != MSAManager._hash_arguments({"B": "GGGAA", "A": "MKTAY"}, "greedy") +def test_msa_data_key_sort_preserves_numeric_order_and_handles_mixed_keys(): + """Ensure Protenix MSA key sorting preserves numeric order and mixed-key stability.""" + keys = [10, 2, "1", "A"] + + assert sorted(keys, key=_msa_data_key_sort_key) == [2, 10, "1", "A"] + + # ============================================================================ # __init__ cache directory handling (test 4) # ============================================================================