diff --git a/src/sampleworks/utils/msa.py b/src/sampleworks/utils/msa.py index 83e2805..5ede991 100644 --- a/src/sampleworks/utils/msa.py +++ b/src/sampleworks/utils/msa.py @@ -1,3 +1,4 @@ +from collections.abc import Mapping from hashlib import sha3_256 from pathlib import Path @@ -18,6 +19,14 @@ MAX_PAIRED_SEQS = 8192 MAX_MSA_SEQS = 16384 +MSAData = Mapping[str, str] | Mapping[int, str] | Mapping[str | int, str] + + +def _msa_data_key_sort_key(key: str | int) -> tuple[int, int | str]: + """Return a deterministic sort key for supported MSA sequence keys.""" + if isinstance(key, int): + return (0, key) + return (1, key) def _validate_msa_cache_contents(msa_hash: str, msa_dir: Path) -> None: @@ -104,7 +113,7 @@ def _validate_msa_cache_contents(msa_hash: str, msa_dir: Path) -> None: # For the love of decent code, don't copy this and use it somewhere else and respect the # leading underscore! def _compute_msa( - data: dict[str | int, str], + data: MSAData, target_id: str, msa_dir: Path, msa_server_url: str, @@ -118,7 +127,7 @@ def _compute_msa( Parameters ---------- - data : dict[str | int, str] + data : MSAData The input protein sequences. target_id : str The target id. @@ -192,7 +201,7 @@ def _compute_msa( # order as they are in `data`, and furthermore just returns a list of strings, the content # of each string being a single sequence alignment. It's some weird file parsing that we # should clean up so users don't break it or have to worry about it. - outputs = {} + outputs: dict[str | int, Path] = {} for idx, name in enumerate(data): # Get paired sequences paired = paired_msas[idx].strip().splitlines() @@ -308,14 +317,14 @@ def __init__( self._cache_hits = 0 @staticmethod - def _hash_arguments(data: dict[str | int, str], msa_pairing_strategy: str) -> str: + def _hash_arguments(data: MSAData, msa_pairing_strategy: str) -> str: encoded_sequence_tuple = str.encode(str(tuple(data.values())) + msa_pairing_strategy) hexdigest = sha3_256(encoded_sequence_tuple).hexdigest() return hexdigest def get_msa( self, - data: dict[str | int, str], + data: MSAData, msa_pairing_strategy: str, structure_predictor: str | StructurePredictor = StructurePredictor.BOLTZ_2, ) -> dict[str | int, Path]: @@ -323,7 +332,7 @@ def get_msa( Parameters ---------- - data : dict[str | int, str] + data : MSAData A dictionary mapping target (usu. chain or index) names to protein sequences. msa_pairing_strategy : str The MSA pairing strategy to use (usually "greedy"). @@ -382,8 +391,9 @@ def get_msa( protenix_dir.mkdir(parents=True, exist_ok=True) # Protenix adds extra information, easiest just to use their pipeline. # make sure sort order stays the same: - data_keys = sorted(data.keys()) - sequences = [data[key] for key in data_keys] + data_items = sorted(data.items(), key=lambda item: _msa_data_key_sort_key(item[0])) + data_keys = [key for key, _ in data_items] + sequences = [sequence for _, sequence in data_items] out_dir = self.msa_dir / "protenix" / hash_key msa_directories = [out_dir / str(idx) for idx in data_keys] @@ -392,7 +402,7 @@ def get_msa( (out_dir / str(idx) / fn).exists() for idx in data_keys for fn in reqd_files ) if need_msas: - msa_directories = protenix_msa_search(sequences, out_dir, mode="protenix") + msa_directories = protenix_msa_search(sequences, str(out_dir), mode="protenix") self._api_calls += 1 else: self._cache_hits += 1 diff --git a/tests/utils/test_msa.py b/tests/utils/test_msa.py index 412fb5d..25f4151 100644 --- a/tests/utils/test_msa.py +++ b/tests/utils/test_msa.py @@ -8,7 +8,7 @@ import pytest from sampleworks.utils import msa as msa_module from sampleworks.utils.guidance_constants import StructurePredictor -from sampleworks.utils.msa import _compute_msa, MSAManager +from sampleworks.utils.msa import _compute_msa, _msa_data_key_sort_key, MSAManager # ============================================================================ @@ -30,6 +30,13 @@ def test_hash_arguments_is_deterministic_and_input_sensitive(): assert base != MSAManager._hash_arguments({"B": "GGGAA", "A": "MKTAY"}, "greedy") +def test_msa_data_key_sort_preserves_numeric_order_and_handles_mixed_keys(): + """Ensure Protenix MSA key sorting preserves numeric order and mixed-key stability.""" + keys = [10, 2, "1", "A"] + + assert sorted(keys, key=_msa_data_key_sort_key) == [2, 10, "1", "A"] + + # ============================================================================ # __init__ cache directory handling (test 4) # ============================================================================