Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ dependencies = [
"e3nn",
"esm",
"biotite",
"pymol-open-source",
"pymol-open-source-whl>=3.1.0.4",
"scipy",
"pandas",
"numpy",
Expand Down
2 changes: 1 addition & 1 deletion scripts/generate_esm_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def compute_esm_embeddings(
"""
try:
# Load ground truth atoms using geometry cache parser in src/dataset.py
protein_atoms, _ = parse_asu_with_biotite(str(pdb_path))
protein_atoms, _, _ = parse_asu_with_biotite(str(pdb_path))
if len(protein_atoms) == 0:
raise ValueError(f"No protein atoms found in {pdb_path}")

Expand Down
6 changes: 5 additions & 1 deletion scripts/generate_slae_embeddings.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
"""
Precompute SLAE embeddings for protein structures and save to separate cache files.

NOTE: This SLAE encoder is legacy and is NOT currently used. We primarily use the
ESM encoder (see scripts/generate_esm_embeddings.py). This script is retained for
reference/reproducibility only.

This script:
1. Reads a split file containing PDB entries
2. For each entry, loads the PDB and converts to atom37 representation
Expand Down Expand Up @@ -400,7 +404,7 @@ def main() -> None:

try:
# protein_atoms: biotite AtomArray with num_atoms atoms
protein_atoms, _ = parse_asu_with_biotite(str(pdb_path))
protein_atoms, _, _ = parse_asu_with_biotite(str(pdb_path))
# coords: (num_residues, 37, 3) - atom37 coordinates
# residue_type: (num_residues,) - residue type indices
# chains: (num_residues,) - chain IDs
Expand Down
70 changes: 61 additions & 9 deletions src/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,16 +47,21 @@ def element_onehot(symbols: list[str]) -> Tensor:

def parse_asu_with_biotite(
path: str,
) -> tuple[bts.AtomArray, bts.AtomArray]:
) -> tuple[bts.AtomArray, bts.AtomArray, bts.AtomArray]:
"""
Parse PDB file and extract protein and water atoms.
Parse PDB file and extract protein, water, and ligand atoms.

Args:
path: Path to PDB file

Returns:
Tuple of (protein_atoms, water_atoms) as biotite AtomArrays.
Hydrogen atoms are excluded.
Tuple of (protein_atoms, water_atoms, ligand_atoms) as biotite AtomArrays.
Hydrogen atoms are excluded. ligand_atoms contains every non-protein,
non-water heavy atom: small-molecule ligands, ions, cofactors, AND
non-amino-acid polymers such as nucleic acids (DNA/RNA). It is deliberately
NOT restricted to HETATM records -- nucleic acids are written as ATOM
records but are kept here as context (their surfaces, especially the
phosphate backbone, order nearby water).

Notes:
- model=1: Uses first model in PDB (standard for X-ray structures)
Expand All @@ -72,10 +77,15 @@ def parse_asu_with_biotite(
protein_mask = bts.filter_amino_acids(atoms)
water_mask = (atoms.res_name == "HOH") | (atoms.res_name == "WAT")

# "ligand" here is broad: every non-protein, non-water heavy atom.
# includes small-molecule ligands, ions, cofactors and even nucleic acids
ligand_mask = ~protein_mask & ~water_mask

protein_atoms = atoms[protein_mask]
water_atoms = atoms[water_mask]
ligand_atoms = atoms[ligand_mask]

return protein_atoms, water_atoms
return protein_atoms, water_atoms, ligand_atoms


def get_crystal_contacts_pymol(
Expand Down Expand Up @@ -665,6 +675,7 @@ def __init__(
base_pdb_dir: str = "/sb/wankowicz_lab/data/srivasv/pdb_redo_data",

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

out side of the diff, but just hard-coded path to remove

cutoff: float = 8.0,
include_mates: bool = True,
include_ligands: bool = True,
geometry_cache_name: str = "geometry",
preprocess: bool = True,
duplicate_single_sample: int = 1,
Expand All @@ -691,10 +702,17 @@ def __init__(
base_pdb_dir: Base directory containing PDB subdirectories
cutoff: Distance cutoff for PP edges and crystal contacts (Angstroms)
include_mates: If True, include symmetry mate atoms as protein nodes
include_ligands: If True (default), include every non-protein,
non-water heavy atom (small-molecule ligands, ions,
cofactors, and nucleic acids) as protein-type nodes.
They are appended after protein (and mate) atoms with a
boolean is_ligand mask and residue_index = -1.
geometry_cache_name: Base name for geometry cache directory. When
include_mates=True, "_mates" is appended automatically.
Default is "geometry", resulting in "geometry/" or
"geometry_mates/" subdirectories.
include_ligands does NOT affect the cache directory
name -- ligand inclusion is part of the dataset config,
not the cache path. Default is "geometry", yielding
"geometry/" or "geometry_mates/".
preprocess: If True, run preprocessing on missing cached files
duplicate_single_sample: If dataset has 1 sample, duplicate it this many times
Quality checks (always active):
Expand All @@ -720,7 +738,9 @@ def __init__(
"""

self.cache_dir = Path(processed_dir)
# Directory-based separation: geometry/ vs geometry_mates/
# Directory-based separation: geometry/ vs geometry_mates/. Ligand inclusion
# is governed by the include_ligands config flag, not the cache directory
# name, so the geometry cache name is unaffected by include_ligands.
cache_suffix = "_mates" if include_mates else ""
self.geometry_dir = self.cache_dir / f"{geometry_cache_name}{cache_suffix}"
self.base_pdb_dir = Path(base_pdb_dir)
Comment on lines 740 to 746

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

include_ligands should be on by default, not a concern here, both caches should have ligands

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this not a concern? I understand that the default is include_ligands=True so for your own caches this is fine, but if another user wants to toggle and try excluding ligands would there not be silent errors from loading caches?

Expand All @@ -731,6 +751,7 @@ def __init__(
else:
self.embedding_dir = None
self.include_mates = include_mates
self.include_ligands = include_ligands
self.duplicate_single_sample = duplicate_single_sample

self.max_com_dist = max_com_dist
Expand Down Expand Up @@ -867,7 +888,7 @@ def _preprocess_one(self, entry: dict, cache_path: Path):
"""
pdb_path = str(entry["pdb_path"])

protein_atoms, water_atoms = parse_asu_with_biotite(pdb_path)
protein_atoms, water_atoms, ligand_atoms = parse_asu_with_biotite(pdb_path)

# check inter-chain interactions for multi-chain proteins
chain_valid, chain_reason, _ = check_chain_interactions(
Expand Down Expand Up @@ -1010,6 +1031,8 @@ def _preprocess_one(self, entry: dict, cache_path: Path):
water_x = torch.zeros((0, len(ELEMENT_VOCAB) + 1), dtype=torch.float32)

# process symmetry mate atoms
# NOTE(ligands+mates): mate ligand het atoms belong here too when
# include_ligands is set -- see TODO at the ASU ligand-append block below.
mate_coords = crystal_data["mate_coords"]
if mate_coords.shape[0] > 0:
mate_pos = torch.tensor(mate_coords, dtype=torch.float32) - center
Expand Down Expand Up @@ -1046,6 +1069,32 @@ def _preprocess_one(self, entry: dict, cache_path: Path):
final_protein_x = protein_x
final_protein_res_idx = protein_res_idx

# Append ligand atoms after protein (and mate) atoms when enabled.
# is_ligand mask marks which protein-type nodes are ligand atoms.
# Ligands always go last so num_asu_protein and mate counts are unaffected,
# preserving ESM/SLAE embedding alignment via _pad_atom_embeddings_for_mates.

# TODO(ligands+mates): this only adds ASU ligands. Until dev_crystal_mates

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am a bit confused by the comment. Is the intention (future PR) to add ligands as part of mate protein nodes or not?

# is opened for a PR, mates are restricted to polymer.protein, so a ligand sitting in a
# crystal contact is dropped from the mates
Comment thread
vratins marked this conversation as resolved.
if self.include_ligands and len(ligand_atoms) > 0:
ligand_pos = torch.tensor(ligand_atoms.coord, dtype=torch.float32) - center
ligand_elements = [str(e).upper() for e in ligand_atoms.element]
ligand_x = element_onehot(ligand_elements)
final_protein_pos = torch.cat([final_protein_pos, ligand_pos], dim=0)
final_protein_x = torch.cat([final_protein_x, ligand_x], dim=0)
# Ligand atoms get residue_index = -1 (sentinel; no residue embedding).
# The is_ligand mask identifies them; residue-pooling masks out these
# negative indices before any scatter (see GVPEncoder._pool_by_residue).
ligand_res_idx = torch.full((len(ligand_atoms),), -1, dtype=torch.long)
final_protein_res_idx = torch.cat(
[final_protein_res_idx, ligand_res_idx], dim=0
)
is_ligand = torch.zeros(final_protein_pos.size(0), dtype=torch.bool)
is_ligand[-len(ligand_atoms) :] = True
else:
is_ligand = torch.zeros(final_protein_pos.size(0), dtype=torch.bool)

# Compute PP edges and features
if final_protein_pos.size(0) > 0:
pp_edge_index = radius_graph(final_protein_pos, r=self.cutoff, loop=False)
Expand All @@ -1071,6 +1120,7 @@ def _preprocess_one(self, entry: dict, cache_path: Path):
"protein_pos": final_protein_pos,
"protein_x": final_protein_x,
"protein_res_idx": final_protein_res_idx,
"is_ligand": is_ligand,
"water_pos": water_pos,
"water_x": water_x,
# PP topology and features (precomputed)
Expand Down Expand Up @@ -1164,6 +1214,7 @@ def __getitem__(self, idx: int) -> HeteroData:
protein_pos = cached["protein_pos"]
protein_x = cached["protein_x"]
protein_res_idx = cached["protein_res_idx"]
is_ligand = cached["is_ligand"]
pp_edge_index = cached["pp_edge_index"]
Comment thread
vratins marked this conversation as resolved.
pp_edge_unit_vectors = cached["pp_edge_unit_vectors"]
pp_edge_rbf = cached["pp_edge_rbf"]
Expand All @@ -1185,6 +1236,7 @@ def __getitem__(self, idx: int) -> HeteroData:
data["protein"].x = protein_x
data["protein"].pos = protein_pos
data["protein"].residue_index = protein_res_idx
data["protein"].is_ligand = is_ligand
data["protein"].num_nodes = protein_pos.size(0)
data["protein"].num_residues = num_residues
data["protein"].num_protein_residues = num_protein_residues
Expand Down
126 changes: 84 additions & 42 deletions src/encoder_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
import torch
import torch.nn as nn

from src.constants import NODE_FEATURE_DIM


if TYPE_CHECKING:
from torch_geometric.data import HeteroData
Expand Down Expand Up @@ -140,57 +142,73 @@ def from_config(cls, config: dict, device: torch.device) -> BaseProteinEncoder:
@register_encoder("slae")
class CachedEmbeddingEncoder(BaseProteinEncoder):
"""
Encoder for pre-computed protein embeddings (ESM, SLAE, etc.).

This pass-through encoder reads embeddings stored in HeteroData under a
specified key and returns them as scalar features. No neural network
computation occurs; all geometric processing happens in downstream layers.
Fusion encoder for pre-computed protein embeddings (ESM, SLAE, etc.).

Each protein-type node is described by two modalities, both already present
on data['protein']:
- a cached sequence/structure embedding (`embedding`, width embedding_dim).
ASU protein atoms carry the real ESM/SLAE vector; symmetry mates and
ligand atoms are zero-padded (they have no residue embedding).
Comment on lines +150 to +151

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not seeing relevant tests for testing if the embeddings work as expected -- whether should be zero-padded or not and have the one-hot component etc for ASU protein, mate, and ligands? I know the mate is not part of this PR, but probably should start adding tests to catch the kind of bug that you found that prevented crystal contact to work.

- a per-atom element one-hot (`x`, width NODE_FEATURE_DIM).

The two are each projected to fusion_dim, concatenated, and passed through a
small MLP to produce per-node scalar features. This gives every atom -- not
just ligands -- its own element identity (ESM is per-residue, so all atoms of
a residue otherwise share an identical vector), and handles ligands/mates
uniformly: their zero-ESM rows simply contribute nothing from esm_proj, so
their fused features are element-driven. No ligand/mate special-casing.

Supported embedding types:
- ESM: Evolutionary Scale Modeling embeddings (https://github.com/evolutionaryscale/esm)
- SLAE: Strictly Local Atom-level Environment Embeddings (https://www.biorxiv.org/content/10.1101/2025.10.03.680398v1)

Embedding dimension is inferred from the data on first forward pass.
Accessing output_dims before forward() raises RuntimeError.
Memory: Cached embeddings are NOT loaded at initialization. The encoder
stores only the key name; actual embeddings are read from data at forward
time, allowing standard PyTorch batching/streaming.

Memory: Embeddings are NOT loaded at initialization. The encoder stores
only the key name; actual embeddings are read from data at forward time,
allowing standard PyTorch batching/streaming.

Note: Returns empty vector features (shape Nx0x3) since cached embeddings
are scalar-only.
Note: Returns empty vector features (shape Nx0x3); fused output is scalar-only.
"""

def __init__(
self, embedding_key: str, encoder_type: str, embedding_dim: int | None = None
self,
embedding_key: str,
encoder_type: str,
embedding_dim: int,
fusion_dim: int,
):
"""
Initialize CachedEmbeddingEncoder.

Args:
embedding_key: Key to look up embeddings in data['protein']
encoder_type: Encoder type identifier ('esm' or 'slae')
embedding_dim: Optional embedding dimension. If provided, output_dims is
available immediately. If None, dimension is inferred on first forward.
embedding_dim: Width of the cached embeddings (e.g. 1536 for ESM,
128 for SLAE). Validated against the data at forward time.
fusion_dim: Output width of the fused scalar features (output_dims[0]).
"""
super().__init__()
self._embedding_dim: int | None = embedding_dim
self._embedding_dim: int = embedding_dim
self._fusion_dim: int = fusion_dim
self._embedding_key = embedding_key
self._encoder_type = encoder_type
# Project each modality to fusion_dim, LayerNorm each stream, then fuse.
# Separate projections + per-stream norm keep the 16-dim element signal from
# being swamped by the wide, large-norm ESM vector (raw ESM norms are ~1e3-1e4
# vs 1 for a one-hot), so both modalities enter the fuse MLP at comparable scale.
self.esm_proj = nn.Linear(embedding_dim, fusion_dim)
self.elem_proj = nn.Linear(NODE_FEATURE_DIM, fusion_dim)
Comment thread
vratins marked this conversation as resolved.
self.esm_norm = nn.LayerNorm(fusion_dim)
self.elem_norm = nn.LayerNorm(fusion_dim)
self.fuse = nn.Sequential(
nn.Linear(2 * fusion_dim, fusion_dim),
nn.SiLU(),
nn.Linear(fusion_dim, fusion_dim),
)

@property
def output_dims(self) -> tuple[int, int]:
"""Return (embedding_dim, 0) — scalars only.

Raises:
RuntimeError: If accessed before first forward pass (dimension not yet inferred)
"""
if self._embedding_dim is None:
raise RuntimeError(
f"{self._encoder_type.upper()} encoder dimension not yet known. "
"Run a forward pass first to infer dimension from data."
)
return self._embedding_dim, 0
"""Return (fusion_dim, 0) — scalars only."""
return self._fusion_dim, 0

@property
def encoder_type(self) -> str:
Expand All @@ -201,17 +219,16 @@ def forward(
self, data: HeteroData
) -> tuple[torch.Tensor, torch.Tensor, tuple | None]:
"""
Read cached embeddings and return (s, V, None).

On first call, infers embedding dimension from the data.
Fuse the cached embedding with the element one-hot and return (s, V, None).

Args:
data: HeteroData with cached embeddings in data['protein']
data: HeteroData with cached embeddings and element features in
data['protein'] ('embedding' and 'x').

Returns:
s: (N, embedding_dim) — raw embeddings
V: (N, 0, 3) — empty vector features
pp_edge_attr: None — cached embedding encoders don't process edges
s: (N, fusion_dim) — fused scalar features
V: (N, 0, 3) — empty vector features
pp_edge_attr: None — cached embedding encoders don't process edges
"""
if self._embedding_key not in data["protein"]:
raise KeyError(
Expand All @@ -221,12 +238,23 @@ def forward(

embeddings = data["protein"][self._embedding_key]

# Infer dimension on first forward
if self._embedding_dim is None:
self._embedding_dim = embeddings.size(-1)
# Validate cached width against the configured embedding_dim, otherwise a
# mismatched cache fails inside esm_proj with an opaque shape error.
if embeddings.size(-1) != self._embedding_dim:
raise ValueError(
f"{self._encoder_type.upper()} encoder configured with "
f"embedding_dim={self._embedding_dim}, but cached "
f"'{self._embedding_key}' embeddings have width {embeddings.size(-1)}. "
f"Ensure the encoder config matches the cached embeddings."
)

x = data["protein"].x.to(embeddings.device)
esm = self.esm_norm(self.esm_proj(embeddings))
elem = self.elem_norm(self.elem_proj(x))
fused = self.fuse(torch.cat([esm, elem], dim=-1))

V = embeddings.new_empty(embeddings.size(0), 0, 3)
return embeddings, V, None
V = fused.new_empty(fused.size(0), 0, 3)
return fused, V, None

@classmethod
def from_config(cls, config: dict, device: torch.device) -> CachedEmbeddingEncoder:
Expand All @@ -237,13 +265,27 @@ def from_config(cls, config: dict, device: torch.device) -> CachedEmbeddingEncod
config: Configuration dictionary with:
- encoder_type: 'esm' or 'slae' (required)
- embedding_key: Optional key name (defaults to 'embedding')
- embedding_dim: Optional embedding dimension (if known upfront)
- embedding_dim: Cached embedding width (required)
- hidden_s: Fused output width / scalar hidden dim (required)
device: Device to place the encoder on

Returns:
Instantiated CachedEmbeddingEncoder

Raises:
ValueError: If 'embedding_dim' or 'hidden_s' is missing from config
"""
encoder_type = config["encoder_type"] # "esm" or "slae"
embedding_key = config.get("embedding_key", "embedding")
embedding_dim = config.get("embedding_dim")
return cls(embedding_key, encoder_type, embedding_dim).to(device)
if embedding_dim is None:
raise ValueError(
f"'{encoder_type}' encoder requires 'embedding_dim' in its config."
)
fusion_dim = config.get("hidden_s")
if fusion_dim is None:
raise ValueError(
f"'{encoder_type}' encoder requires 'hidden_s' (fused output width) "
f"in its config."
)
return cls(embedding_key, encoder_type, embedding_dim, fusion_dim).to(device)
Loading
Loading