Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ dependencies = [
"e3nn",
"esm",
"biotite",
"pymol-open-source",
"pymol-open-source-whl>=3.1.0.4",
"scipy",
"pandas",
"numpy",
Expand Down
2 changes: 1 addition & 1 deletion scripts/generate_esm_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def compute_esm_embeddings(
"""
try:
# Load ground truth atoms using geometry cache parser in src/dataset.py
protein_atoms, _ = parse_asu_with_biotite(str(pdb_path))
protein_atoms, _, _ = parse_asu_with_biotite(str(pdb_path))
if len(protein_atoms) == 0:
raise ValueError(f"No protein atoms found in {pdb_path}")

Expand Down
6 changes: 5 additions & 1 deletion scripts/generate_slae_embeddings.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
"""
Precompute SLAE embeddings for protein structures and save to separate cache files.

NOTE: This SLAE encoder is legacy and is NOT currently used. We primarily use the
ESM encoder (see scripts/generate_esm_embeddings.py). This script is retained for
reference/reproducibility only.

This script:
1. Reads a split file containing PDB entries
2. For each entry, loads the PDB and converts to atom37 representation
Expand Down Expand Up @@ -400,7 +404,7 @@ def main() -> None:

try:
# protein_atoms: biotite AtomArray with num_atoms atoms
protein_atoms, _ = parse_asu_with_biotite(str(pdb_path))
protein_atoms, _, _ = parse_asu_with_biotite(str(pdb_path))
# coords: (num_residues, 37, 3) - atom37 coordinates
# residue_type: (num_residues,) - residue type indices
# chains: (num_residues,) - chain IDs
Expand Down
65 changes: 58 additions & 7 deletions src/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,16 +47,21 @@ def element_onehot(symbols: list[str]) -> Tensor:

def parse_asu_with_biotite(
path: str,
) -> tuple[bts.AtomArray, bts.AtomArray]:
) -> tuple[bts.AtomArray, bts.AtomArray, bts.AtomArray]:
"""
Comment thread
vratins marked this conversation as resolved.
Parse PDB file and extract protein and water atoms.
Parse PDB file and extract protein, water, and ligand atoms.

Args:
path: Path to PDB file

Returns:
Tuple of (protein_atoms, water_atoms) as biotite AtomArrays.
Hydrogen atoms are excluded.
Tuple of (protein_atoms, water_atoms, ligand_atoms) as biotite AtomArrays.
Hydrogen atoms are excluded. ligand_atoms contains every non-protein,
non-water heavy atom: small-molecule ligands, ions, cofactors, AND
non-amino-acid polymers such as nucleic acids (DNA/RNA). It is deliberately
NOT restricted to HETATM records -- nucleic acids are written as ATOM
records but are kept here as context (their surfaces, especially the
phosphate backbone, order nearby water).

Notes:
- model=1: Uses first model in PDB (standard for X-ray structures)
Expand All @@ -72,10 +77,15 @@ def parse_asu_with_biotite(
protein_mask = bts.filter_amino_acids(atoms)
water_mask = (atoms.res_name == "HOH") | (atoms.res_name == "WAT")

# "ligand" here is broad: every non-protein, non-water heavy atom.
# includes small-molecule ligands, ions, cofactors and even nucleic acids
ligand_mask = ~protein_mask & ~water_mask

protein_atoms = atoms[protein_mask]
water_atoms = atoms[water_mask]
ligand_atoms = atoms[ligand_mask]
Comment thread
vratins marked this conversation as resolved.

return protein_atoms, water_atoms
return protein_atoms, water_atoms, ligand_atoms
Comment thread
vratins marked this conversation as resolved.


def get_crystal_contacts_pymol(
Expand Down Expand Up @@ -665,6 +675,7 @@ def __init__(
base_pdb_dir: str = "/sb/wankowicz_lab/data/srivasv/pdb_redo_data",
cutoff: float = 8.0,
include_mates: bool = True,
include_ligands: bool = False,
geometry_cache_name: str = "geometry",
preprocess: bool = True,
duplicate_single_sample: int = 1,
Expand All @@ -691,8 +702,14 @@ def __init__(
base_pdb_dir: Base directory containing PDB subdirectories
cutoff: Distance cutoff for PP edges and crystal contacts (Angstroms)
include_mates: If True, include symmetry mate atoms as protein nodes
include_ligands: If True, include every non-protein, non-water heavy
atom (small-molecule ligands, ions, cofactors, and
nucleic acids) as protein-type nodes. They are appended
after protein (and mate) atoms with a boolean is_ligand
mask and residue_index = -1.
geometry_cache_name: Base name for geometry cache directory. When
include_mates=True, "_mates" is appended automatically.
When include_ligands=True, "_lig" is appended.
Default is "geometry", resulting in "geometry/", "geometry_mates/",
"geometry_lig/" or "geometry_mates_lig/" subdirectories.
preprocess: If True, run preprocessing on missing cached files
Expand Down Expand Up @@ -720,8 +737,10 @@ def __init__(
"""

self.cache_dir = Path(processed_dir)
# Directory-based separation: geometry/ vs geometry_mates/
# Directory-based separation: geometry/ vs geometry_mates/ vs geometry_lig/ etc.
cache_suffix = "_mates" if include_mates else ""
if include_ligands:
cache_suffix += "_lig"
self.geometry_dir = self.cache_dir / f"{geometry_cache_name}{cache_suffix}"
self.base_pdb_dir = Path(base_pdb_dir)
self.cutoff = cutoff
Expand All @@ -731,6 +750,7 @@ def __init__(
else:
self.embedding_dir = None
self.include_mates = include_mates
self.include_ligands = include_ligands
self.duplicate_single_sample = duplicate_single_sample

self.max_com_dist = max_com_dist
Expand Down Expand Up @@ -867,7 +887,7 @@ def _preprocess_one(self, entry: dict, cache_path: Path):
"""
pdb_path = str(entry["pdb_path"])

protein_atoms, water_atoms = parse_asu_with_biotite(pdb_path)
protein_atoms, water_atoms, ligand_atoms = parse_asu_with_biotite(pdb_path)

# check inter-chain interactions for multi-chain proteins
chain_valid, chain_reason, _ = check_chain_interactions(
Expand Down Expand Up @@ -1010,6 +1030,8 @@ def _preprocess_one(self, entry: dict, cache_path: Path):
water_x = torch.zeros((0, len(ELEMENT_VOCAB) + 1), dtype=torch.float32)

# process symmetry mate atoms
# NOTE(ligands+mates): mate ligand het atoms belong here too when
# include_ligands is set -- see TODO at the ASU ligand-append block below.
mate_coords = crystal_data["mate_coords"]
if mate_coords.shape[0] > 0:
mate_pos = torch.tensor(mate_coords, dtype=torch.float32) - center
Expand Down Expand Up @@ -1046,6 +1068,32 @@ def _preprocess_one(self, entry: dict, cache_path: Path):
final_protein_x = protein_x
final_protein_res_idx = protein_res_idx

# Append ligand atoms after protein (and mate) atoms when enabled.
# is_ligand mask marks which protein-type nodes are ligand atoms.
# Ligands always go last so num_asu_protein and mate counts are unaffected,
# preserving ESM/SLAE embedding alignment via _pad_atom_embeddings_for_mates.

# TODO(ligands+mates): this only adds ASU ligands. Until dev_crystal_mates
# is opened for a PR, mates are restricted to polymer.protein, so a ligand
# sitting in a crystal contact is dropped from the mates.
if self.include_ligands and len(ligand_atoms) > 0:
ligand_pos = torch.tensor(ligand_atoms.coord, dtype=torch.float32) - center
ligand_elements = [str(e).upper() for e in ligand_atoms.element]
ligand_x = element_onehot(ligand_elements)
final_protein_pos = torch.cat([final_protein_pos, ligand_pos], dim=0)
final_protein_x = torch.cat([final_protein_x, ligand_x], dim=0)
# Ligand atoms get residue_index = -1 (sentinel; no residue embedding).
# The is_ligand mask identifies them; residue-pooling masks out these
# negative indices before any scatter (see GVPEncoder._pool_by_residue).
ligand_res_idx = torch.full((len(ligand_atoms),), -1, dtype=torch.long)
final_protein_res_idx = torch.cat(
[final_protein_res_idx, ligand_res_idx], dim=0
)
is_ligand = torch.zeros(final_protein_pos.size(0), dtype=torch.bool)
is_ligand[-len(ligand_atoms) :] = True
else:
is_ligand = torch.zeros(final_protein_pos.size(0), dtype=torch.bool)

# Compute PP edges and features
if final_protein_pos.size(0) > 0:
pp_edge_index = radius_graph(final_protein_pos, r=self.cutoff, loop=False)
Expand All @@ -1071,6 +1119,7 @@ def _preprocess_one(self, entry: dict, cache_path: Path):
"protein_pos": final_protein_pos,
"protein_x": final_protein_x,
"protein_res_idx": final_protein_res_idx,
"is_ligand": is_ligand,
"water_pos": water_pos,
"water_x": water_x,
# PP topology and features (precomputed)
Expand Down Expand Up @@ -1164,6 +1213,7 @@ def __getitem__(self, idx: int) -> HeteroData:
protein_pos = cached["protein_pos"]
protein_x = cached["protein_x"]
protein_res_idx = cached["protein_res_idx"]
is_ligand = cached["is_ligand"]
pp_edge_index = cached["pp_edge_index"]
pp_edge_unit_vectors = cached["pp_edge_unit_vectors"]
pp_edge_rbf = cached["pp_edge_rbf"]
Expand All @@ -1185,6 +1235,7 @@ def __getitem__(self, idx: int) -> HeteroData:
data["protein"].x = protein_x
data["protein"].pos = protein_pos
data["protein"].residue_index = protein_res_idx
data["protein"].is_ligand = is_ligand
data["protein"].num_nodes = protein_pos.size(0)
data["protein"].num_residues = num_residues
data["protein"].num_protein_residues = num_protein_residues
Expand Down
74 changes: 44 additions & 30 deletions src/encoder_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
import torch
import torch.nn as nn

from src.constants import NODE_FEATURE_DIM


if TYPE_CHECKING:
from torch_geometric.data import HeteroData
Expand Down Expand Up @@ -142,54 +144,51 @@ class CachedEmbeddingEncoder(BaseProteinEncoder):
"""
Encoder for pre-computed protein embeddings (ESM, SLAE, etc.).

This pass-through encoder reads embeddings stored in HeteroData under a
specified key and returns them as scalar features. No neural network
computation occurs; all geometric processing happens in downstream layers.
This near pass-through encoder reads embeddings stored in HeteroData under a
specified key and returns them as scalar features. The only learnable
component is ligand_embed, a projection applied to ligand atoms (which have
no cached embedding); all geometric processing happens in downstream layers.

Supported embedding types:
- ESM: Evolutionary Scale Modeling embeddings (https://github.com/evolutionaryscale/esm)
- SLAE: Strictly Local Atom-level Environment Embeddings (https://www.biorxiv.org/content/10.1101/2025.10.03.680398v1)

Embedding dimension is inferred from the data on first forward pass.
Accessing output_dims before forward() raises RuntimeError.
Embedding dimension must be provided at construction so output_dims is
available immediately and ligand_embed can be created in __init__.

Memory: Embeddings are NOT loaded at initialization. The encoder stores
only the key name; actual embeddings are read from data at forward time,
allowing standard PyTorch batching/streaming.
Memory: Cached embeddings are NOT loaded at initialization. The encoder
stores only the key name; actual embeddings are read from data at forward
time, allowing standard PyTorch batching/streaming.

Note: Returns empty vector features (shape Nx0x3) since cached embeddings
are scalar-only.
"""

def __init__(
self, embedding_key: str, encoder_type: str, embedding_dim: int | None = None
):
def __init__(self, embedding_key: str, encoder_type: str, embedding_dim: int):
"""
Initialize CachedEmbeddingEncoder.

Args:
embedding_key: Key to look up embeddings in data['protein']
encoder_type: Encoder type identifier ('esm' or 'slae')
embedding_dim: Optional embedding dimension. If provided, output_dims is
available immediately. If None, dimension is inferred on first forward.
embedding_dim: Dimension of the cached embeddings. Required so that
output_dims is available immediately and the ligand projection can
be created in __init__ (see ligand_embed below).
"""
super().__init__()
self._embedding_dim: int | None = embedding_dim
self._embedding_dim: int = embedding_dim
self._embedding_key = embedding_key
self._encoder_type = encoder_type
# Learnable projection for ligand atoms (element one-hot -> embedding space).
# Ligands have no ESM/SLAE embeddings; this replaces zero-padding with a
# learned representation parameterized only by element type. Created here in
# __init__ (not forward) so its parameters are registered before the
# optimizer is built and are replicated/synchronized under DDP/FSDP.
self.ligand_embed = nn.Linear(NODE_FEATURE_DIM, embedding_dim, bias=False)

@property
def output_dims(self) -> tuple[int, int]:
"""Return (embedding_dim, 0) — scalars only.

Raises:
RuntimeError: If accessed before first forward pass (dimension not yet inferred)
"""
if self._embedding_dim is None:
raise RuntimeError(
f"{self._encoder_type.upper()} encoder dimension not yet known. "
"Run a forward pass first to infer dimension from data."
)
"""Return (embedding_dim, 0) — scalars only."""
return self._embedding_dim, 0

@property
Expand All @@ -203,13 +202,15 @@ def forward(
"""
Read cached embeddings and return (s, V, None).

On first call, infers embedding dimension from the data.
If ligand atoms are present (data['protein'].is_ligand), their zero-padded
embedding rows are replaced with a learned projection from element one-hot
features.

Args:
data: HeteroData with cached embeddings in data['protein']

Returns:
s: (N, embedding_dim) — raw embeddings
s: (N, embedding_dim) — embeddings (ligand rows via learned projection)
V: (N, 0, 3) — empty vector features
pp_edge_attr: None — cached embedding encoders don't process edges
"""
Expand All @@ -221,9 +222,15 @@ def forward(

embeddings = data["protein"][self._embedding_key]

# Infer dimension on first forward
if self._embedding_dim is None:
self._embedding_dim = embeddings.size(-1)
# Move the mask and node features onto the embeddings' device first:
# boolean indexing requires the mask and indexed tensor to share a device,
# which may not hold if a caller moved only embeddings (e.g. to GPU).
lig_mask = getattr(data["protein"], "is_ligand", None)
if lig_mask is not None and lig_mask.any():
lig_mask = lig_mask.to(embeddings.device)
x = data["protein"].x.to(embeddings.device)
embeddings = embeddings.clone()
embeddings[lig_mask] = self.ligand_embed(x[lig_mask])

V = embeddings.new_empty(embeddings.size(0), 0, 3)
return embeddings, V, None
Expand All @@ -237,13 +244,20 @@ def from_config(cls, config: dict, device: torch.device) -> CachedEmbeddingEncod
config: Configuration dictionary with:
- encoder_type: 'esm' or 'slae' (required)
- embedding_key: Optional key name (defaults to 'embedding')
- embedding_dim: Optional embedding dimension (if known upfront)
- embedding_dim: Embedding dimension (required)
device: Device to place the encoder on

Returns:
Instantiated CachedEmbeddingEncoder

Raises:
ValueError: If 'embedding_dim' is missing from config
"""
encoder_type = config["encoder_type"] # "esm" or "slae"
embedding_key = config.get("embedding_key", "embedding")
embedding_dim = config.get("embedding_dim")
if embedding_dim is None:
raise ValueError(
f"'{encoder_type}' encoder requires 'embedding_dim' in its config."
)
return cls(embedding_key, encoder_type, embedding_dim).to(device)
7 changes: 7 additions & 0 deletions src/gvp_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,13 @@ def _pool_by_residue(
Returns:
(num_residues, embed_dim) pooled residue embeddings
"""
# Ligand atoms carry residue_index = -1 (no parent residue). Drop them so
# scatter ops only ever see valid non-negative residue indices.
if (residue_index < 0).any():
valid = residue_index >= 0
atom_embed = atom_embed[valid]
residue_index = residue_index[valid]

aggr = self.pool_aggr
if aggr == "mean":
return scatter_mean(atom_embed, residue_index, dim=0, dim_size=num_residues)
Expand Down
6 changes: 6 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,12 @@ def pdb_1deu():
return _resolve_pdb_path("1deu")


@pytest.fixture
def pdb_4h0b():
    """Resolve the 4h0b test structure for the ligand-support tests.

    4h0b is used because it has non-water ligand HETATMs.
    """
    # Same lookup helper as the other PDB fixtures; only the entry ID differs.
    pdb_id = "4h0b"
    return _resolve_pdb_path(pdb_id)


# ============== Shared encoder fixtures ==============


Expand Down
Loading