Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
# ty: ignore
# ruff: ignore
"""Cromer-Mann coefficents and related constants

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
# ty: ignore

"""Minimal components of qFit's transformer module necessary to implement
qFit's volume.py

Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
# type: ignore
# ty: ignore

"""Classes for handling unit cell transformation."""

from itertools import product
Expand Down
9 changes: 5 additions & 4 deletions src/sampleworks/core/samplers/edm.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from collections.abc import Mapping
from dataclasses import dataclass
from typing import TYPE_CHECKING
from typing import cast, TYPE_CHECKING

import einx
import torch
Expand Down Expand Up @@ -395,7 +395,8 @@ def step(

t_hat = context.t_effective
dt = context.dt
eps_scale = context.noise_scale
# check_context() guarantees noise_scale is present for sampling steps.
eps_scale = cast(float, context.noise_scale)
allow_gradients = True if scaler and getattr(scaler, "requires_gradients", False) else False

centroid = einx.mean("... [n] c", state)
Expand All @@ -415,7 +416,7 @@ def step(

# Store eps separately for proper frame transformation
# eps_scale will be float if check_context didn't raise
eps = torch.randn_like(maybe_augmented_state) * eps_scale # ty: ignore[unsupported-operator]
eps = torch.randn_like(maybe_augmented_state) * eps_scale
noisy_state = maybe_augmented_state + eps
noisy_state = torch.as_tensor(noisy_state).detach().requires_grad_(allow_gradients)

Expand Down Expand Up @@ -507,7 +508,7 @@ def step(
# multiple particles.
# Only compute when noise_var > 0 to avoid division by near-zero
# (matching Boltz behavior)
noise_var = eps_scale**2 # ty: ignore[unsupported-operator]
noise_var = eps_scale**2
if noise_var > 0:
log_proposal_correction = einx.sum(
"... [b n c]", eps_working_frame**2 - (eps_working_frame + proposal_shift) ** 2
Expand Down
20 changes: 12 additions & 8 deletions src/sampleworks/metrics/lddt.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,11 +336,13 @@ def compute(
# Set token ids to something useful for residue-level LDDT (chain ID + residue number)
# Note: add_global_token_id_annotation works with both AtomArray and
# AtomArrayStack at runtime
predicted_atom_array_stack = add_global_token_id_annotation(
predicted_atom_array_stack # ty: ignore[invalid-argument-type]
predicted_atom_array_stack = cast(

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems like something we should fix in AtomWorks or push into our own macromolecular-observables package. The AtomWorks problem is that they don't properly annotate the input/output types. Probably we should just put it in our own package and move away from AtomWorks.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In any event, I dislike cast because of the visual distraction, so we should fix the underlying problem, but for this PR I don't think anything needs to be done except to create an issue.

AtomArrayStack,
add_global_token_id_annotation(cast(Any, predicted_atom_array_stack)),
)
ground_truth_atom_array_stack = add_global_token_id_annotation(
ground_truth_atom_array_stack # ty: ignore[invalid-argument-type]
ground_truth_atom_array_stack = cast(
AtomArrayStack,
add_global_token_id_annotation(cast(Any, ground_truth_atom_array_stack)),
)

# restrict to atoms that are present in both structures
Expand Down Expand Up @@ -477,11 +479,13 @@ def compute(

# set the token ids, to avoid any possible confusion later on
# Note: add_global_token_id_annotation works with AtomArrayStack at runtime
filtered_predicted = add_global_token_id_annotation(
filtered_predicted # ty: ignore[invalid-argument-type]
filtered_predicted = cast(
AtomArrayStack,
add_global_token_id_annotation(cast(Any, filtered_predicted)),
)
filtered_ground_truth = add_global_token_id_annotation(
filtered_ground_truth # ty: ignore[invalid-argument-type]
filtered_ground_truth = cast(
AtomArrayStack,
add_global_token_id_annotation(cast(Any, filtered_ground_truth)),
)

lddt_features = extract_lddt_features_from_atom_arrays(
Expand Down
10 changes: 6 additions & 4 deletions src/sampleworks/metrics/rmsd.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,11 +92,13 @@ def compute(
- ``best_of_{N}_segment_rmsd``: minimum of ``segment_rmsd``.
"""
# 1. Annotate token IDs so atoms can be grouped into residues downstream.
predicted_atom_array_stack = add_global_token_id_annotation(
predicted_atom_array_stack # ty: ignore[invalid-argument-type] (accepts AtomArray|AtomArrayStack at runtime; stub is narrower)
predicted_atom_array_stack = cast(
AtomArrayStack,
add_global_token_id_annotation(cast(Any, predicted_atom_array_stack)),
)
ground_truth_atom_array_stack = add_global_token_id_annotation(
ground_truth_atom_array_stack # ty: ignore[invalid-argument-type] (accepts AtomArray|AtomArrayStack at runtime; stub is narrower)
ground_truth_atom_array_stack = cast(
AtomArrayStack,
add_global_token_id_annotation(cast(Any, ground_truth_atom_array_stack)),
)

# 2. Restrict both stacks to atoms present in both structures, in matching order.
Expand Down
17 changes: 11 additions & 6 deletions src/sampleworks/models/boltz/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -454,7 +454,7 @@ def create_boltz_input_from_structure(
sequence_to_chains.setdefault(seq, []).append(chain_id)

unique_chain_sequences = {chains[0]: seq for seq, chains in sequence_to_chains.items()}
msa_paths_unique = msa_manager.get_msa(unique_chain_sequences, msa_pairing_strategy) # ty: ignore[invalid-argument-type]
msa_paths_unique = msa_manager.get_msa(unique_chain_sequences, msa_pairing_strategy)

msa_paths = {}
for seq, chains_with_seq in sequence_to_chains.items():
Expand Down Expand Up @@ -739,7 +739,9 @@ def featurize(self, structure: dict) -> GenerativeModelInput[BoltzConditioning]:
"x_init from prior. This means align_to_input will not work properly,"
" and reward functions dependent on this won't be accurate."
)
temp_features = GenerativeModelInput(x_init=None, conditioning=conditioning) # ty: ignore[invalid-argument-type]
temp_features = GenerativeModelInput[BoltzConditioning](
x_init=cast(Any, None), conditioning=conditioning
)
x_init = self.initialize_from_prior(batch_size=ensemble_size, features=temp_features)

return GenerativeModelInput(x_init=x_init, conditioning=conditioning)
Expand Down Expand Up @@ -810,7 +812,7 @@ def _pairformer_pass(
z = z + msa_module(z, s_inputs, features, use_kernels=self.model.use_kernels)

if self.model.is_pairformer_compiled:
pairformer_module = self.model.pairformer_module._orig_mod
pairformer_module = cast(Any, self.model.pairformer_module)._orig_mod
else:
pairformer_module = self.model.pairformer_module

Expand Down Expand Up @@ -1078,8 +1080,9 @@ def _setup_data_module(
)

processed_dir = out_dir / "processed"
load_manifest = cast(Any, Manifest.load)
processed = BoltzProcessedInput(
manifest=Manifest.load(processed_dir / "manifest.json"),
manifest=load_manifest(processed_dir / "manifest.json"),
targets_dir=processed_dir / "structures",
msa_dir=processed_dir / "msa",
constraints_dir=(processed_dir / "constraints")
Expand Down Expand Up @@ -1194,7 +1197,9 @@ def featurize(self, structure: dict) -> GenerativeModelInput[BoltzConditioning]:
"x_init from prior. This means align_to_input will not work properly,"
" and reward functions dependent on this won't be accurate."
)
temp_features = GenerativeModelInput(x_init=None, conditioning=conditioning) # ty: ignore[invalid-argument-type]
temp_features = GenerativeModelInput[BoltzConditioning](
x_init=cast(Any, None), conditioning=conditioning
)
x_init = self.initialize_from_prior(batch_size=ensemble_size, features=temp_features)

return GenerativeModelInput(x_init=x_init, conditioning=conditioning)
Expand Down Expand Up @@ -1368,7 +1373,7 @@ def _pairformer_pass(
)

if self.model.is_pairformer_compiled:
pairformer_module = self.model.pairformer_module._orig_mod
pairformer_module = cast(Any, self.model.pairformer_module)._orig_mod
else:
pairformer_module = self.model.pairformer_module

Expand Down
6 changes: 4 additions & 2 deletions src/sampleworks/models/rf3/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,8 +365,10 @@ def featurize(self, structure: dict) -> GenerativeModelInput[RF3Conditioning]:
InferenceInput, input_batch[0]
) # since we're not batching, the loader returns a list of length 1

# (Hydra instantiation of pipeline means it is going to be hard to type check here)
pipeline_output = self.inference_engine.pipeline(input_spec.to_pipeline_input()) # ty: ignore[call-non-callable]
# Hydra instantiation leaves pipeline imprecisely typed even though it is callable
# at runtime.
pipeline = cast(Any, self.inference_engine.pipeline)
pipeline_output = pipeline(input_spec.to_pipeline_input())
pipeline_output = trainer.fabric.to_device(pipeline_output)

features = trainer._assemble_network_inputs(pipeline_output)
Expand Down
6 changes: 3 additions & 3 deletions src/sampleworks/utils/framework_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from collections.abc import Callable
from functools import wraps
from typing import Any, overload, ParamSpec, TYPE_CHECKING, TypeVar
from typing import Any, cast, overload, ParamSpec, TYPE_CHECKING, TypeVar

import numpy as np

Expand Down Expand Up @@ -184,8 +184,8 @@ def match_batch(array: Array | np.ndarray, target_batch_size: int) -> Array | np

# singleton: lazy broadcast (no copy)
if b == 1:
return _broadcast_to(array, (n, *array.shape[1:])) # ty: ignore[invalid-argument-type]
return _broadcast_to(cast(Any, array), (n, *array.shape[1:]))
# divisible: tile
if n % b:
raise ValueError(f"batch {b} not divisible into target {n}")
return _tile(array, (n // b, *(1,) * (array.ndim - 1))) # ty: ignore[invalid-argument-type]
return _tile(cast(Any, array), (n // b, *(1,) * (array.ndim - 1)))
28 changes: 19 additions & 9 deletions src/sampleworks/utils/msa.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from collections.abc import Mapping
from hashlib import sha3_256
from pathlib import Path

Expand All @@ -18,6 +19,14 @@

MAX_PAIRED_SEQS = 8192
MAX_MSA_SEQS = 16384
MSAData = Mapping[str, str] | Mapping[int, str] | Mapping[str | int, str]

Comment on lines 20 to +23

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@xraymemory I think I agree with copilot here, and would use dict explicitly. Does that not get rid of the warnings for _compute_msa?


def _msa_data_key_sort_key(key: str | int) -> tuple[int, int | str]:
"""Return a deterministic sort key for supported MSA sequence keys."""
if isinstance(key, int):
return (0, key)
return (1, key)


def _validate_msa_cache_contents(msa_hash: str, msa_dir: Path) -> None:
Expand Down Expand Up @@ -104,7 +113,7 @@ def _validate_msa_cache_contents(msa_hash: str, msa_dir: Path) -> None:
# For the love of decent code, don't copy this and use it somewhere else and respect the
# leading underscore!
def _compute_msa(
data: dict[str | int, str],
data: MSAData,
target_id: str,
msa_dir: Path,
msa_server_url: str,
Expand All @@ -118,7 +127,7 @@ def _compute_msa(

Parameters
----------
data : dict[str | int, str]
data : MSAData
The input protein sequences.
target_id : str
The target id.
Expand Down Expand Up @@ -192,7 +201,7 @@ def _compute_msa(
# order as they are in `data`, and furthermore just returns a list of strings, the content
# of each string being a single sequence alignment. It's some weird file parsing that we
# should clean up so users don't break it or have to worry about it.
outputs = {}
outputs: dict[str | int, Path] = {}
for idx, name in enumerate(data):
# Get paired sequences
paired = paired_msas[idx].strip().splitlines()
Expand Down Expand Up @@ -308,22 +317,22 @@ def __init__(
self._cache_hits = 0

@staticmethod
def _hash_arguments(data: dict[str | int, str], msa_pairing_strategy: str) -> str:
def _hash_arguments(data: MSAData, msa_pairing_strategy: str) -> str:
encoded_sequence_tuple = str.encode(str(tuple(data.values())) + msa_pairing_strategy)
hexdigest = sha3_256(encoded_sequence_tuple).hexdigest()
return hexdigest

def get_msa(
self,
data: dict[str | int, str],
data: MSAData,
msa_pairing_strategy: str,
structure_predictor: str | StructurePredictor = StructurePredictor.BOLTZ_2,
) -> dict[str | int, Path]:
"""Fetches existing MSA files from disk or computes new ones if necessary.

Parameters
----------
data : dict[str | int, str]
data : MSAData
A dictionary mapping target (usu. chain or index) names to protein sequences.
msa_pairing_strategy : str
The MSA pairing strategy to use (usually "greedy").
Expand Down Expand Up @@ -382,8 +391,9 @@ def get_msa(
protenix_dir.mkdir(parents=True, exist_ok=True)
# Protenix adds extra information, easiest just to use their pipeline.
# make sure sort order stays the same:
data_keys = sorted(data.keys())
sequences = [data[key] for key in data_keys]
data_items = sorted(data.items(), key=lambda item: _msa_data_key_sort_key(item[0]))

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe add a comment what _msa_data_key_sort_key does (and maybe give it a better name?)

data_keys = [key for key, _ in data_items]
sequences = [sequence for _, sequence in data_items]
out_dir = self.msa_dir / "protenix" / hash_key

msa_directories = [out_dir / str(idx) for idx in data_keys]
Expand All @@ -392,7 +402,7 @@ def get_msa(
(out_dir / str(idx) / fn).exists() for idx in data_keys for fn in reqd_files
)
if need_msas:
msa_directories = protenix_msa_search(sequences, out_dir, mode="protenix")
msa_directories = protenix_msa_search(sequences, str(out_dir), mode="protenix")
self._api_calls += 1
else:
self._cache_hits += 1
Expand Down
6 changes: 4 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -1089,5 +1089,7 @@ def perturbed_coords(
"""
torch.manual_seed(42)
base = converging_mock_wrapper.target
perturbation = torch.randn_like(base) * 0.1 # ty: ignore[invalid-argument-type]
return base, base + perturbation # ty: ignore[invalid-return-type, unsupported-operator]
if base is None:
raise ValueError("converging_mock_wrapper must provide target coordinates")
perturbation = torch.randn_like(base) * 0.1
return base, base + perturbation
16 changes: 8 additions & 8 deletions tests/utils/test_atom_array_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Tests for atom_array_utils module."""

from typing import cast
from typing import Any, cast

import numpy as np
import pytest
Expand Down Expand Up @@ -200,12 +200,12 @@ def test_invalid_type_raises_error(self):
invalid_input = "not an atom array"

with pytest.raises(TypeError, match="can only accept AtomArray or AtomArrayStack"):
select_altloc(invalid_input, "A") # ty: ignore[invalid-argument-type]
select_altloc(cast(Any, invalid_input), "A")

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm surprised this doesn't still flag it as an invalid argument type. Can you explain?


def test_none_input_raises_error(self):
"""Test that None input raises TypeError."""
with pytest.raises(TypeError, match="can only accept AtomArray or AtomArrayStack"):
select_altloc(None, "A") # ty: ignore[invalid-argument-type]
select_altloc(cast(Any, None), "A")


class TestSelectAltlocEdgeCases:
Expand Down Expand Up @@ -358,15 +358,15 @@ def test_invalid_type_first_arg(self):
array.coord = np.random.rand(3, 3)

with pytest.raises(TypeError, match="must be AtomArray or AtomArrayStack"):
filter_to_common_atoms("not an array", array) # ty: ignore[invalid-argument-type]
filter_to_common_atoms(cast(Any, "not an array"), array)

def test_invalid_type_second_arg(self):
"""Test that invalid second argument raises TypeError."""
array = AtomArray(3)
array.coord = np.random.rand(3, 3)

with pytest.raises(TypeError, match="must be AtomArray or AtomArrayStack"):
filter_to_common_atoms(array, None) # ty: ignore[invalid-argument-type]
filter_to_common_atoms(array, cast(Any, None))

def test_preserves_coordinates(self, atom_array_partial_overlap):
"""Test that coordinates are preserved for common atoms."""
Expand Down Expand Up @@ -443,7 +443,7 @@ def test_is_idempotent(self, structure_fixture, request):

def test_invalid_type_raises_error(self):
with pytest.raises(TypeError, match="can only accept AtomArray or AtomArrayStack"):
remove_hydrogens("bad input") # ty:ignore[invalid-argument-type]
remove_hydrogens(cast(Any, "bad input"))


class TestKeepPolymer:
Expand All @@ -467,7 +467,7 @@ def test_atom_array_and_stack_give_same_atom_count(self, test_structure):

def test_invalid_type_raises_error(self):
with pytest.raises(TypeError, match="can only accept AtomArray or AtomArrayStack"):
keep_polymer(42) # ty:ignore[invalid-argument-type]
keep_polymer(cast(Any, 42))


class TestKeepAminoAcids:
Expand All @@ -494,7 +494,7 @@ def test_atom_array_and_stack_give_same_atom_count(self, test_structure):

def test_invalid_type_raises_error(self):
with pytest.raises(TypeError, match="can only accept AtomArray or AtomArrayStack"):
keep_amino_acids(None) # ty:ignore[invalid-argument-type]
keep_amino_acids(cast(Any, None))


class TestFilterFunctionsIntegration:
Expand Down
6 changes: 4 additions & 2 deletions tests/utils/test_framework_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Tests for framework conversion utilities from sampleworks.utils.framework_utils."""

from typing import Any, cast

import numpy as np
import pytest
from sampleworks.utils.framework_utils import (
Expand Down Expand Up @@ -527,7 +529,7 @@ def test_raises_on_incompatible_batch_sizes(self):
def test_raises_on_scalar_input(self):
scalar = np.float64(1.0)
with pytest.raises(ValueError, match="ndim >= 1"):
match_batch(scalar, target_batch_size=2) # ty: ignore[no-matching-overload]
match_batch(cast(Any, scalar), target_batch_size=2)

def test_1d_array(self):
array = np.array([42.0])
Expand All @@ -549,4 +551,4 @@ def test_preserves_dtype(self):

def test_raises_on_unsupported_type(self):
with pytest.raises(TypeError, match="unsupported array type"):
match_batch([1, 2, 3], target_batch_size=2) # ty: ignore[no-matching-overload]
match_batch(cast(Any, [1, 2, 3]), target_batch_size=2)
Loading
Loading