Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 19 additions & 9 deletions src/sampleworks/utils/msa.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from collections.abc import Mapping
from hashlib import sha3_256
from pathlib import Path

Expand All @@ -18,6 +19,14 @@

# NOTE(review): presumably upper bounds on the number of paired / total sequences kept in an
# MSA -- their use is not visible here, confirm against `_compute_msa`.
MAX_PAIRED_SEQS = 8192
MAX_MSA_SEQS = 16384
# Input type for MSA requests: a read-only mapping from a target key (str or int) to a
# sequence string. It is spelled as a union of three mappings because `Mapping` is invariant
# in its key type, so a plain `dict[str, str]` or `dict[int, str]` would not type-check
# against `Mapping[str | int, str]` alone.
MSAData = Mapping[str, str] | Mapping[int, str] | Mapping[str | int, str]


def _msa_data_key_sort_key(key: str | int) -> tuple[int, int | str]:
"""Return a deterministic sort key for supported MSA sequence keys."""
if isinstance(key, int):
return (0, key)
return (1, key)


def _validate_msa_cache_contents(msa_hash: str, msa_dir: Path) -> None:
Expand Down Expand Up @@ -104,7 +113,7 @@ def _validate_msa_cache_contents(msa_hash: str, msa_dir: Path) -> None:
# For the love of decent code, don't copy this and use it somewhere else and respect the
# leading underscore!
def _compute_msa(
data: dict[str | int, str],
data: MSAData,
target_id: str,
Comment on lines 115 to 117
msa_dir: Path,
msa_server_url: str,
Expand All @@ -118,7 +127,7 @@ def _compute_msa(

Parameters
----------
data : dict[str | int, str]
data : MSAData
The input protein sequences.
target_id : str
The target id.
Expand Down Expand Up @@ -192,7 +201,7 @@ def _compute_msa(
# order as they are in `data`, and furthermore just returns a list of strings, the content
# of each string being a single sequence alignment. It's some weird file parsing that we
# should clean up so users don't break it or have to worry about it.
outputs = {}
outputs: dict[str | int, Path] = {}
for idx, name in enumerate(data):
# Get paired sequences
paired = paired_msas[idx].strip().splitlines()
Expand Down Expand Up @@ -308,22 +317,22 @@ def __init__(
self._cache_hits = 0

@staticmethod
def _hash_arguments(data: MSAData, msa_pairing_strategy: str) -> str:
    """Return the SHA3-256 hex digest identifying an MSA request.

    The digest is computed over the string form of the tuple of sequence values,
    taken in the mapping's iteration order, followed by the pairing strategy.
    The mapping's keys do not contribute to the hash, but the order of its values
    does.

    Parameters
    ----------
    data : MSAData
        Mapping from target keys to sequences; only the values are hashed.
    msa_pairing_strategy : str
        The MSA pairing strategy, appended to the hashed text.

    Returns
    -------
    str
        Hexadecimal SHA3-256 digest of the encoded text.
    """
    # Keep the exact text layout (tuple repr + strategy): changing it would change
    # the digest returned for the same inputs.
    encoded_sequence_tuple = str.encode(str(tuple(data.values())) + msa_pairing_strategy)
    return sha3_256(encoded_sequence_tuple).hexdigest()

def get_msa(
self,
data: dict[str | int, str],
data: MSAData,
msa_pairing_strategy: str,
Comment on lines 325 to 328
structure_predictor: str | StructurePredictor = StructurePredictor.BOLTZ_2,
) -> dict[str | int, Path]:
"""Fetches existing MSA files from disk or computes new ones if necessary.

Parameters
----------
data : dict[str | int, str]
data : MSAData
A dictionary mapping target (usu. chain or index) names to protein sequences.
msa_pairing_strategy : str
The MSA pairing strategy to use (usually "greedy").
Expand Down Expand Up @@ -382,8 +391,9 @@ def get_msa(
protenix_dir.mkdir(parents=True, exist_ok=True)
# Protenix adds extra information, easiest just to use their pipeline.
# make sure sort order stays the same:
data_keys = sorted(data.keys())
sequences = [data[key] for key in data_keys]
data_items = sorted(data.items(), key=lambda item: _msa_data_key_sort_key(item[0]))
data_keys = [key for key, _ in data_items]
sequences = [sequence for _, sequence in data_items]
out_dir = self.msa_dir / "protenix" / hash_key

msa_directories = [out_dir / str(idx) for idx in data_keys]
Expand All @@ -392,7 +402,7 @@ def get_msa(
(out_dir / str(idx) / fn).exists() for idx in data_keys for fn in reqd_files
)
if need_msas:
msa_directories = protenix_msa_search(sequences, out_dir, mode="protenix")
msa_directories = protenix_msa_search(sequences, str(out_dir), mode="protenix")
self._api_calls += 1
else:
self._cache_hits += 1
Expand Down
9 changes: 8 additions & 1 deletion tests/utils/test_msa.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import pytest
from sampleworks.utils import msa as msa_module
from sampleworks.utils.guidance_constants import StructurePredictor
from sampleworks.utils.msa import _compute_msa, MSAManager
from sampleworks.utils.msa import _compute_msa, _msa_data_key_sort_key, MSAManager


# ============================================================================
Expand All @@ -30,6 +30,13 @@ def test_hash_arguments_is_deterministic_and_input_sensitive():
assert base != MSAManager._hash_arguments({"B": "GGGAA", "A": "MKTAY"}, "greedy")


def test_msa_data_key_sort_preserves_numeric_order_and_handles_mixed_keys():
    """Ensure Protenix MSA key sorting preserves numeric order and mixed-key stability."""
    # Integers must sort numerically (2 before 10) and ahead of every string key,
    # including the digit string "1".
    keys = [10, 2, "1", "A"]

    assert sorted(keys, key=_msa_data_key_sort_key) == [2, 10, "1", "A"]


# ============================================================================
# __init__ cache directory handling (test 4)
# ============================================================================
Expand Down
Loading