-
Notifications
You must be signed in to change notification settings - Fork 1
Adding ligand processing to dataset and encoders. #86
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
38c2e0b
b8568a5
7cf6de1
1529943
bebf742
c678fc0
e78be52
c824e2f
aca3180
6f14f4a
4979184
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -47,16 +47,21 @@ def element_onehot(symbols: list[str]) -> Tensor: | |
|
|
||
| def parse_asu_with_biotite( | ||
| path: str, | ||
| ) -> tuple[bts.AtomArray, bts.AtomArray]: | ||
| ) -> tuple[bts.AtomArray, bts.AtomArray, bts.AtomArray]: | ||
| """ | ||
| Parse PDB file and extract protein and water atoms. | ||
| Parse PDB file and extract protein, water, and ligand atoms. | ||
|
|
||
| Args: | ||
| path: Path to PDB file | ||
|
|
||
| Returns: | ||
| Tuple of (protein_atoms, water_atoms) as biotite AtomArrays. | ||
| Hydrogen atoms are excluded. | ||
| Tuple of (protein_atoms, water_atoms, ligand_atoms) as biotite AtomArrays. | ||
| Hydrogen atoms are excluded. ligand_atoms contains every non-protein, | ||
| non-water heavy atom: small-molecule ligands, ions, cofactors, AND | ||
| non-amino-acid polymers such as nucleic acids (DNA/RNA). It is deliberately | ||
| NOT restricted to HETATM records -- nucleic acids are written as ATOM | ||
| records but are kept here as context (their surfaces, especially the | ||
| phosphate backbone, order nearby water). | ||
|
|
||
| Notes: | ||
| - model=1: Uses first model in PDB (standard for X-ray structures) | ||
|
|
@@ -72,10 +77,15 @@ def parse_asu_with_biotite( | |
| protein_mask = bts.filter_amino_acids(atoms) | ||
| water_mask = (atoms.res_name == "HOH") | (atoms.res_name == "WAT") | ||
|
|
||
| # "ligand" here is broad: every non-protein, non-water heavy atom. | ||
| # includes small-molecule ligands, ions, cofactors and even nucleic acids | ||
| ligand_mask = ~protein_mask & ~water_mask | ||
|
|
||
| protein_atoms = atoms[protein_mask] | ||
| water_atoms = atoms[water_mask] | ||
| ligand_atoms = atoms[ligand_mask] | ||
|
|
||
| return protein_atoms, water_atoms | ||
| return protein_atoms, water_atoms, ligand_atoms | ||
|
|
||
|
|
||
| def get_crystal_contacts_pymol( | ||
|
|
@@ -665,6 +675,7 @@ def __init__( | |
| base_pdb_dir: str = "/sb/wankowicz_lab/data/srivasv/pdb_redo_data", | ||
| cutoff: float = 8.0, | ||
| include_mates: bool = True, | ||
| include_ligands: bool = True, | ||
| geometry_cache_name: str = "geometry", | ||
| preprocess: bool = True, | ||
| duplicate_single_sample: int = 1, | ||
|
|
@@ -691,10 +702,17 @@ def __init__( | |
| base_pdb_dir: Base directory containing PDB subdirectories | ||
| cutoff: Distance cutoff for PP edges and crystal contacts (Angstroms) | ||
| include_mates: If True, include symmetry mate atoms as protein nodes | ||
| include_ligands: If True (default), include every non-protein, | ||
| non-water heavy atom (small-molecule ligands, ions, | ||
| cofactors, and nucleic acids) as protein-type nodes. | ||
| They are appended after protein (and mate) atoms with a | ||
| boolean is_ligand mask and residue_index = -1. | ||
| geometry_cache_name: Base name for geometry cache directory. When | ||
| include_mates=True, "_mates" is appended automatically. | ||
| Default is "geometry", resulting in "geometry/" or | ||
| "geometry_mates/" subdirectories. | ||
| include_ligands does NOT affect the cache directory | ||
| name -- ligand inclusion is part of the dataset config, | ||
| not the cache path. Default is "geometry", yielding | ||
| "geometry/" or "geometry_mates/". | ||
| preprocess: If True, run preprocessing on missing cached files | ||
| duplicate_single_sample: If dataset has 1 sample, duplicate it this many times | ||
| Quality checks (always active): | ||
|
|
@@ -720,7 +738,9 @@ def __init__( | |
| """ | ||
|
|
||
| self.cache_dir = Path(processed_dir) | ||
| # Directory-based separation: geometry/ vs geometry_mates/ | ||
| # Directory-based separation: geometry/ vs geometry_mates/. Ligand inclusion | ||
| # is governed by the include_ligands config flag, not the cache directory | ||
| # name, so the geometry cache name is unaffected by include_ligands. | ||
| cache_suffix = "_mates" if include_mates else "" | ||
| self.geometry_dir = self.cache_dir / f"{geometry_cache_name}{cache_suffix}" | ||
| self.base_pdb_dir = Path(base_pdb_dir) | ||
|
Comment on lines
740
to
746
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. include_ligands should be on by default, so this is not a concern here; both caches should have ligands.
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Why is this not a concern? I understand that the default is include_ligands=True, so for your own caches this is fine, but if another user wants to toggle it and try excluding ligands, would there not be silent errors from loading caches? |
||
|
|
@@ -731,6 +751,7 @@ def __init__( | |
| else: | ||
| self.embedding_dir = None | ||
| self.include_mates = include_mates | ||
| self.include_ligands = include_ligands | ||
| self.duplicate_single_sample = duplicate_single_sample | ||
|
|
||
| self.max_com_dist = max_com_dist | ||
|
|
@@ -867,7 +888,7 @@ def _preprocess_one(self, entry: dict, cache_path: Path): | |
| """ | ||
| pdb_path = str(entry["pdb_path"]) | ||
|
|
||
| protein_atoms, water_atoms = parse_asu_with_biotite(pdb_path) | ||
| protein_atoms, water_atoms, ligand_atoms = parse_asu_with_biotite(pdb_path) | ||
|
|
||
| # check inter-chain interactions for multi-chain proteins | ||
| chain_valid, chain_reason, _ = check_chain_interactions( | ||
|
|
@@ -1010,6 +1031,8 @@ def _preprocess_one(self, entry: dict, cache_path: Path): | |
| water_x = torch.zeros((0, len(ELEMENT_VOCAB) + 1), dtype=torch.float32) | ||
|
|
||
| # process symmetry mate atoms | ||
| # NOTE(ligands+mates): mate ligand het atoms belong here too when | ||
| # include_ligands is set -- see TODO at the ASU ligand-append block below. | ||
| mate_coords = crystal_data["mate_coords"] | ||
| if mate_coords.shape[0] > 0: | ||
| mate_pos = torch.tensor(mate_coords, dtype=torch.float32) - center | ||
|
|
@@ -1046,6 +1069,32 @@ def _preprocess_one(self, entry: dict, cache_path: Path): | |
| final_protein_x = protein_x | ||
| final_protein_res_idx = protein_res_idx | ||
|
|
||
| # Append ligand atoms after protein (and mate) atoms when enabled. | ||
| # is_ligand mask marks which protein-type nodes are ligand atoms. | ||
| # Ligands always go last so num_asu_protein and mate counts are unaffected, | ||
| # preserving ESM/SLAE embedding alignment via _pad_atom_embeddings_for_mates. | ||
|
|
||
| # TODO(ligands+mates): this only adds ASU ligands. Until dev_crystal_mates | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I am a bit confused by the comment. Is the intention (in a future PR) to add ligands as part of the mate protein nodes or not? |
||
| # is opened for a PR, mates are restricted to polymer.protein, so a ligand sitting in a | ||
| # crystal contact is dropped from the mates | ||
|
vratins marked this conversation as resolved.
|
||
| if self.include_ligands and len(ligand_atoms) > 0: | ||
| ligand_pos = torch.tensor(ligand_atoms.coord, dtype=torch.float32) - center | ||
| ligand_elements = [str(e).upper() for e in ligand_atoms.element] | ||
| ligand_x = element_onehot(ligand_elements) | ||
| final_protein_pos = torch.cat([final_protein_pos, ligand_pos], dim=0) | ||
| final_protein_x = torch.cat([final_protein_x, ligand_x], dim=0) | ||
| # Ligand atoms get residue_index = -1 (sentinel; no residue embedding). | ||
| # The is_ligand mask identifies them; residue-pooling masks out these | ||
| # negative indices before any scatter (see GVPEncoder._pool_by_residue). | ||
| ligand_res_idx = torch.full((len(ligand_atoms),), -1, dtype=torch.long) | ||
| final_protein_res_idx = torch.cat( | ||
| [final_protein_res_idx, ligand_res_idx], dim=0 | ||
| ) | ||
| is_ligand = torch.zeros(final_protein_pos.size(0), dtype=torch.bool) | ||
| is_ligand[-len(ligand_atoms) :] = True | ||
| else: | ||
| is_ligand = torch.zeros(final_protein_pos.size(0), dtype=torch.bool) | ||
|
|
||
| # Compute PP edges and features | ||
| if final_protein_pos.size(0) > 0: | ||
| pp_edge_index = radius_graph(final_protein_pos, r=self.cutoff, loop=False) | ||
|
|
@@ -1071,6 +1120,7 @@ def _preprocess_one(self, entry: dict, cache_path: Path): | |
| "protein_pos": final_protein_pos, | ||
| "protein_x": final_protein_x, | ||
| "protein_res_idx": final_protein_res_idx, | ||
| "is_ligand": is_ligand, | ||
| "water_pos": water_pos, | ||
| "water_x": water_x, | ||
| # PP topology and features (precomputed) | ||
|
|
@@ -1164,6 +1214,7 @@ def __getitem__(self, idx: int) -> HeteroData: | |
| protein_pos = cached["protein_pos"] | ||
| protein_x = cached["protein_x"] | ||
| protein_res_idx = cached["protein_res_idx"] | ||
| is_ligand = cached["is_ligand"] | ||
| pp_edge_index = cached["pp_edge_index"] | ||
|
vratins marked this conversation as resolved.
|
||
| pp_edge_unit_vectors = cached["pp_edge_unit_vectors"] | ||
| pp_edge_rbf = cached["pp_edge_rbf"] | ||
|
|
@@ -1185,6 +1236,7 @@ def __getitem__(self, idx: int) -> HeteroData: | |
| data["protein"].x = protein_x | ||
| data["protein"].pos = protein_pos | ||
| data["protein"].residue_index = protein_res_idx | ||
| data["protein"].is_ligand = is_ligand | ||
| data["protein"].num_nodes = protein_pos.size(0) | ||
| data["protein"].num_residues = num_residues | ||
| data["protein"].num_protein_residues = num_protein_residues | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -15,6 +15,8 @@ | |
| import torch | ||
| import torch.nn as nn | ||
|
|
||
| from src.constants import NODE_FEATURE_DIM | ||
|
|
||
|
|
||
| if TYPE_CHECKING: | ||
| from torch_geometric.data import HeteroData | ||
|
|
@@ -140,57 +142,73 @@ def from_config(cls, config: dict, device: torch.device) -> BaseProteinEncoder: | |
| @register_encoder("slae") | ||
| class CachedEmbeddingEncoder(BaseProteinEncoder): | ||
| """ | ||
| Encoder for pre-computed protein embeddings (ESM, SLAE, etc.). | ||
|
|
||
| This pass-through encoder reads embeddings stored in HeteroData under a | ||
| specified key and returns them as scalar features. No neural network | ||
| computation occurs; all geometric processing happens in downstream layers. | ||
| Fusion encoder for pre-computed protein embeddings (ESM, SLAE, etc.). | ||
|
|
||
| Each protein-type node is described by two modalities, both already present | ||
| on data['protein']: | ||
| - a cached sequence/structure embedding (`embedding`, width embedding_dim). | ||
| ASU protein atoms carry the real ESM/SLAE vector; symmetry mates and | ||
| ligand atoms are zero-padded (they have no residue embedding). | ||
|
Comment on lines
+150
to
+151
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I am not seeing relevant tests that check whether the embeddings work as expected -- whether they should be zero-padded or not, whether they have the one-hot component, etc., for ASU protein, mate, and ligand atoms. I know the mates are not part of this PR, but we should probably start adding tests to catch the kind of bug that you found, which prevented crystal contacts from working. |
||
| - a per-atom element one-hot (`x`, width NODE_FEATURE_DIM). | ||
|
|
||
| The two are each projected to fusion_dim, concatenated, and passed through a | ||
| small MLP to produce per-node scalar features. This gives every atom -- not | ||
| just ligands -- its own element identity (ESM is per-residue, so all atoms of | ||
| a residue otherwise share an identical vector), and handles ligands/mates | ||
| uniformly: their zero-ESM rows simply contribute nothing from esm_proj, so | ||
| their fused features are element-driven. No ligand/mate special-casing. | ||
|
|
||
| Supported embedding types: | ||
| - ESM: Evolutionary Scale Modeling embeddings (https://github.com/evolutionaryscale/esm) | ||
| - SLAE: Strictly Local Atom-level Environment Embeddings (https://www.biorxiv.org/content/10.1101/2025.10.03.680398v1) | ||
|
|
||
| Embedding dimension is inferred from the data on first forward pass. | ||
| Accessing output_dims before forward() raises RuntimeError. | ||
| Memory: Cached embeddings are NOT loaded at initialization. The encoder | ||
| stores only the key name; actual embeddings are read from data at forward | ||
| time, allowing standard PyTorch batching/streaming. | ||
|
|
||
| Memory: Embeddings are NOT loaded at initialization. The encoder stores | ||
| only the key name; actual embeddings are read from data at forward time, | ||
| allowing standard PyTorch batching/streaming. | ||
|
|
||
| Note: Returns empty vector features (shape Nx0x3) since cached embeddings | ||
| are scalar-only. | ||
| Note: Returns empty vector features (shape Nx0x3); fused output is scalar-only. | ||
| """ | ||
|
|
||
| def __init__( | ||
| self, embedding_key: str, encoder_type: str, embedding_dim: int | None = None | ||
| self, | ||
| embedding_key: str, | ||
| encoder_type: str, | ||
| embedding_dim: int, | ||
| fusion_dim: int, | ||
| ): | ||
| """ | ||
| Initialize CachedEmbeddingEncoder. | ||
|
|
||
| Args: | ||
| embedding_key: Key to look up embeddings in data['protein'] | ||
| encoder_type: Encoder type identifier ('esm' or 'slae') | ||
| embedding_dim: Optional embedding dimension. If provided, output_dims is | ||
| available immediately. If None, dimension is inferred on first forward. | ||
| embedding_dim: Width of the cached embeddings (e.g. 1536 for ESM, | ||
| 128 for SLAE). Validated against the data at forward time. | ||
| fusion_dim: Output width of the fused scalar features (output_dims[0]). | ||
| """ | ||
| super().__init__() | ||
| self._embedding_dim: int | None = embedding_dim | ||
| self._embedding_dim: int = embedding_dim | ||
| self._fusion_dim: int = fusion_dim | ||
| self._embedding_key = embedding_key | ||
| self._encoder_type = encoder_type | ||
| # Project each modality to fusion_dim, LayerNorm each stream, then fuse. | ||
| # Separate projections + per-stream norm keep the 16-dim element signal from | ||
| # being swamped by the wide, large-norm ESM vector (raw ESM norms are ~1e3-1e4 | ||
| # vs 1 for a one-hot), so both modalities enter the fuse MLP at comparable scale. | ||
| self.esm_proj = nn.Linear(embedding_dim, fusion_dim) | ||
| self.elem_proj = nn.Linear(NODE_FEATURE_DIM, fusion_dim) | ||
|
vratins marked this conversation as resolved.
|
||
| self.esm_norm = nn.LayerNorm(fusion_dim) | ||
| self.elem_norm = nn.LayerNorm(fusion_dim) | ||
| self.fuse = nn.Sequential( | ||
| nn.Linear(2 * fusion_dim, fusion_dim), | ||
| nn.SiLU(), | ||
| nn.Linear(fusion_dim, fusion_dim), | ||
| ) | ||
|
|
||
| @property | ||
| def output_dims(self) -> tuple[int, int]: | ||
| """Return (embedding_dim, 0) — scalars only. | ||
|
|
||
| Raises: | ||
| RuntimeError: If accessed before first forward pass (dimension not yet inferred) | ||
| """ | ||
| if self._embedding_dim is None: | ||
| raise RuntimeError( | ||
| f"{self._encoder_type.upper()} encoder dimension not yet known. " | ||
| "Run a forward pass first to infer dimension from data." | ||
| ) | ||
| return self._embedding_dim, 0 | ||
| """Return (fusion_dim, 0) — scalars only.""" | ||
| return self._fusion_dim, 0 | ||
|
|
||
| @property | ||
| def encoder_type(self) -> str: | ||
|
|
@@ -201,17 +219,16 @@ def forward( | |
| self, data: HeteroData | ||
| ) -> tuple[torch.Tensor, torch.Tensor, tuple | None]: | ||
| """ | ||
| Read cached embeddings and return (s, V, None). | ||
|
|
||
| On first call, infers embedding dimension from the data. | ||
| Fuse the cached embedding with the element one-hot and return (s, V, None). | ||
|
|
||
| Args: | ||
| data: HeteroData with cached embeddings in data['protein'] | ||
| data: HeteroData with cached embeddings and element features in | ||
| data['protein'] ('embedding' and 'x'). | ||
|
|
||
| Returns: | ||
| s: (N, embedding_dim) — raw embeddings | ||
| V: (N, 0, 3) — empty vector features | ||
| pp_edge_attr: None — cached embedding encoders don't process edges | ||
| s: (N, fusion_dim) — fused scalar features | ||
| V: (N, 0, 3) — empty vector features | ||
| pp_edge_attr: None — cached embedding encoders don't process edges | ||
| """ | ||
| if self._embedding_key not in data["protein"]: | ||
| raise KeyError( | ||
|
|
@@ -221,12 +238,23 @@ def forward( | |
|
|
||
| embeddings = data["protein"][self._embedding_key] | ||
|
|
||
| # Infer dimension on first forward | ||
| if self._embedding_dim is None: | ||
| self._embedding_dim = embeddings.size(-1) | ||
| # Validate cached width against the configured embedding_dim, otherwise a | ||
| # mismatched cache fails inside esm_proj with an opaque shape error. | ||
| if embeddings.size(-1) != self._embedding_dim: | ||
| raise ValueError( | ||
| f"{self._encoder_type.upper()} encoder configured with " | ||
| f"embedding_dim={self._embedding_dim}, but cached " | ||
| f"'{self._embedding_key}' embeddings have width {embeddings.size(-1)}. " | ||
| f"Ensure the encoder config matches the cached embeddings." | ||
| ) | ||
|
|
||
| x = data["protein"].x.to(embeddings.device) | ||
| esm = self.esm_norm(self.esm_proj(embeddings)) | ||
| elem = self.elem_norm(self.elem_proj(x)) | ||
| fused = self.fuse(torch.cat([esm, elem], dim=-1)) | ||
|
|
||
| V = embeddings.new_empty(embeddings.size(0), 0, 3) | ||
| return embeddings, V, None | ||
| V = fused.new_empty(fused.size(0), 0, 3) | ||
| return fused, V, None | ||
|
|
||
| @classmethod | ||
| def from_config(cls, config: dict, device: torch.device) -> CachedEmbeddingEncoder: | ||
|
|
@@ -237,13 +265,27 @@ def from_config(cls, config: dict, device: torch.device) -> CachedEmbeddingEncod | |
| config: Configuration dictionary with: | ||
| - encoder_type: 'esm' or 'slae' (required) | ||
| - embedding_key: Optional key name (defaults to 'embedding') | ||
| - embedding_dim: Optional embedding dimension (if known upfront) | ||
| - embedding_dim: Cached embedding width (required) | ||
| - hidden_s: Fused output width / scalar hidden dim (required) | ||
| device: Device to place the encoder on | ||
|
|
||
| Returns: | ||
| Instantiated CachedEmbeddingEncoder | ||
|
|
||
| Raises: | ||
| ValueError: If 'embedding_dim' or 'hidden_s' is missing from config | ||
| """ | ||
| encoder_type = config["encoder_type"] # "esm" or "slae" | ||
| embedding_key = config.get("embedding_key", "embedding") | ||
| embedding_dim = config.get("embedding_dim") | ||
| return cls(embedding_key, encoder_type, embedding_dim).to(device) | ||
| if embedding_dim is None: | ||
| raise ValueError( | ||
| f"'{encoder_type}' encoder requires 'embedding_dim' in its config." | ||
| ) | ||
| fusion_dim = config.get("hidden_s") | ||
| if fusion_dim is None: | ||
| raise ValueError( | ||
| f"'{encoder_type}' encoder requires 'hidden_s' (fused output width) " | ||
| f"in its config." | ||
| ) | ||
| return cls(embedding_key, encoder_type, embedding_dim, fusion_dim).to(device) | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Outside of the diff, but there is a hard-coded path that should be removed.