Skip to content

KNN edge construction fix + align residue indexing with ESM sanitization#82

Open
vratins wants to merge 7 commits into
mainfrom
dev_knn_fix
Open

KNN edge construction fix + align residue indexing with ESM sanitization#82
vratins wants to merge 7 commits into
mainfrom
dev_knn_fix

Conversation

@vratins

@vratins vratins commented Jun 17, 2026

Copy link
Copy Markdown
Contributor
  • build_knn_edges was calling knn(x=dst_pos, y=src_pos), which queries each src point's nearest dst points instead of each dst point's nearest src points; swapped the call and the resulting index rows to fix this. This was previously masked by taking the union of edges on both sides. A future PR will use KNN as a fallback to radius graphs. Expanded the docstring to document that the query is per-destination, so every destination is guaranteed incoming edges (row 1) while a source that is nobody's nearest neighbor may be absent from row 0 — and updated the water-water coverage tests to assert on the destination row (row 1) accordingly, removing the now-unnecessary xfail markers.
  • match_atoms_to_coords now also handles an empty atoms array instead of only an empty target_coords array.
  • Residue indices for the protein graph are now computed via the same residue-name canonicalization the ESM embedding generation script uses (THREE_TO_ONE -> ONE_TO_THREE, unknowns -> UNK) before counting residue boundaries with biotite's get_residue_starts. Without this, two residues that share (chain, resid, ins_code) but had different original res_names could get merged into one under ESM's sanitization but stay separate here, desyncing residue counts/indices from the stored ESM embeddings. Insertion codes are now also normalized (normalize_ins_code) before counting, since get_residue_starts splits on ins_code too — a blank vs. placeholder code ('' / '.' / '?') would otherwise split or merge residues differently from the ESM script's residue keys. Both the canonicalization (sanitize_res_names_for_esm) and the insertion-code normalization are now shared between src/dataset.py and scripts/generate_esm_embeddings.py to prevent the two paths from drifting apart.
  • Cleanup: Removed a dead duplicate build_knn_edges in src/utils.py (the canonical one lives in src/flow.py) that still carried the old per-source semantics, and removed the orphaned atom37_to_atoms / ATOM37_FILL helper (unused outside its own tests; the SLAE pipeline uses its own implementation). Added regression tests covering residue-count alignment between the dataset and the ESM residue keys.

Summary by CodeRabbit

  • Bug Fixes

    • Improved residue-index alignment by using shared ESM-compatible residue-name sanitization for both embedding generation and dataset preprocessing.
    • Hardened atom-to-coordinate matching to safely handle empty inputs.
    • Corrected KNN edge construction direction to consistently follow the documented source/destination convention (improving downstream water connectivity).
  • Tests

    • Updated water-water edge assertions to match destination/query edge coverage.
    • Added/expanded tests for ESM residue-name sanitization behavior.
  • Documentation

    • Updated dataset description to reflect water-position prediction.

Copilot AI review requested due to automatic review settings June 17, 2026 04:23
@coderabbitai

coderabbitai Bot commented Jun 17, 2026

Copy link
Copy Markdown

Review Change Stack

📝 Walkthrough

Walkthrough

The PR adds shared ESM residue-name sanitization, uses it in embedding generation and dataset residue indexing, and updates KNN edge construction to emit source-to-destination directed edges. Related tests now assert the updated residue and edge semantics.

Changes

ESM Residue Sanitization Alignment

Layer / File(s) Summary
Utility and embedding sanitation
src/utils.py, scripts/generate_esm_embeddings.py
sanitize_res_names_for_esm is added, residue-name mapping imports are updated, and the embedding script now uses the साझा helper before clearing hetero flags.
Dataset residue indexing
src/dataset.py
Protein atoms are sanitized before residue-start counting, atom-to-residue indices are assigned with np.searchsorted, and the empty-atom guard and dataset docstring are updated.
Sanitization test coverage
tests/test_utils.py
The utility tests now cover residue-name canonicalization, insertion-code normalization behavior, and alignment between sanitized residue starts and ESM-style residue-key counting.

KNN Edge Direction Correction

Layer / File(s) Summary
Directed edge construction
src/flow.py
build_knn_edges now queries KNN with source and destination positions arranged to produce directed edges with source in row 0 and destination in row 1.
Water-edge assertions
tests/test_flow.py
The water-edge tests now assert coverage using the destination/query row for both single-graph and batched cases.

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~40 minutes

Poem

🐇 I hop through residues, neat and clean,
With ESM-style names in between.
Edges now point where they ought to go,
Row one says “dest,” row zero says “so!”
The tests do a twirl, the buffers all sing,
And bunny-approved correctness is now in spring.

🚥 Pre-merge checks | ✅ 5
✅ Passed checks (5 passed)
Check name Status Explanation
Title check ✅ Passed Clearly summarizes the main KNN edge fix and ESM residue-index alignment; concise and specific.
Docstring Coverage ✅ Passed Docstring coverage is 83.33% which is sufficient. The required threshold is 80.00%.
Linked Issues check ✅ Passed Check skipped because no linked issues were found for this pull request.
Out of Scope Changes check ✅ Passed Check skipped because no linked issues were found for this pull request.
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
📝 Generate docstrings
  • Create stacked PR
  • Commit on current branch
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Commit unit tests in branch dev_knn_fix

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands.

Copilot AI left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR fixes KNN-based edge construction directionality in the flow model, hardens dataset preprocessing for edge cases, and aligns protein residue indexing with the same residue-name sanitization used when generating cached ESM embeddings.

Changes:

  • Fix build_knn_edges to query per-destination point (and swap returned index rows) so edges align with intended src→dst semantics.
  • Update dataset preprocessing to handle empty atoms inputs and to compute residue indices using ESM-style residue-name canonicalization before residue-boundary detection.
  • Adjust tests by adding an xfail marker for a batched edge-connectivity test (though the WW coverage assertions likely need updating instead).

Reviewed changes

Copilot reviewed 3 out of 4 changed files in this pull request and generated 6 comments.

File Description
uv.lock Updates locked dependencies (adds jaxtyping, removes mypy, adjusts some wheels/metadata).
tests/test_flow.py Marks the batched water-edge connectivity test as xfail(strict=True) and updates rationale text.
src/flow.py Fixes KNN query argument order and explicitly swaps index rows to preserve src→dst edge_index layout.
src/dataset.py Handles empty atoms in coordinate matching; updates residue indexing to mirror ESM sanitization behavior.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment thread src/flow.py
Comment thread tests/test_flow.py Outdated
Comment thread tests/test_flow.py Outdated
Comment thread tests/test_flow.py Outdated
Comment thread src/dataset.py Outdated
Comment thread src/dataset.py Outdated

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
tests/test_flow.py (1)

582-617: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Narrow the xfail scope so it doesn’t hide pw regressions.

xfail is currently applied to the whole test, so failures in the protein-water assertions are also treated as expected. That drops useful coverage beyond the known ww issue.

Suggested split to preserve `pw` coverage
-    `@pytest.mark.xfail`(
+    def test_batched_waters_have_protein_edges(self, batched_hetero_data):
+        """Ensure all waters in a batched graph have protein-water edges."""
+        updater = ProteinWaterUpdate(hidden_dims=(128, 16), layers=1)
+        edge_dict = updater.build_edges(batched_hetero_data, k_pw=4, k_ww=3)
+        pw_edges = edge_dict[("protein", "pw", "water")]
+        n_water = batched_hetero_data["water"].num_nodes
+        water_nodes_with_pw_edges = torch.unique(pw_edges[1])
+        assert len(water_nodes_with_pw_edges) == n_water, (
+            f"Only {len(water_nodes_with_pw_edges)}/{n_water} waters have protein edges in batched data"
+        )
+
+    `@pytest.mark.xfail`(
         reason=(
             "build_knn_edges' src/dst argument-order fix changes self-graph (ww) "
             "edge direction: row 0 now holds discovered neighbors rather than query "
             "points, so a point that is nobody's k-nearest neighbor can be dropped "
             "from coverage. The fixed-degree k_pw/k_ww KNN approach is replaced by "
             "radius-based edges + KNN-fallback-for-isolated-nodes in a future PR "
             "(edge type flags & dynamic edge construction), which removes the "
             "k_pw/k_ww params and fixes this guarantee structurally. will remove this "
             "marker when that PR is created."
         ),
         strict=True,
     )
-    def test_batched_waters_have_edges(self, batched_hetero_data):
-        """Ensure all waters in a batched graph have edges."""
+    def test_batched_waters_have_water_edges(self, batched_hetero_data):
+        """Ensure all waters in a batched graph have water-water edges."""
         updater = ProteinWaterUpdate(hidden_dims=(128, 16), layers=1)
-
         edge_dict = updater.build_edges(batched_hetero_data, k_pw=4, k_ww=3)
-        pw_edges = edge_dict[("protein", "pw", "water")]
         ww_edges = edge_dict[("water", "ww", "water")]
-
         n_water = batched_hetero_data["water"].num_nodes
-
-        # Check protein-water edges
-        water_nodes_with_pw_edges = torch.unique(pw_edges[1])
-        assert len(water_nodes_with_pw_edges) == n_water, (
-            f"Only {len(water_nodes_with_pw_edges)}/{n_water} waters have protein edges in batched data"
-        )
-
-        # Check water-water edges
         if n_water > 1:
             water_nodes_with_ww_edges = torch.unique(ww_edges[0])
             assert len(water_nodes_with_ww_edges) == n_water, (
                 f"Only {len(water_nodes_with_ww_edges)}/{n_water} waters have water-water edges in batched data"
             )
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tests/test_flow.py` around lines 582 - 617, The xfail marker is currently
applied to the entire test_batched_waters_have_edges function, which hides
failures in the protein-water edge assertions that should not be expected to
fail. Remove the xfail decorator from the function and instead apply it only to
the water-water edge checking section (the assertions checking
water_nodes_with_ww_edges). This can be done by either splitting the test into
two separate test functions with xfail only on the water-water test, or by
wrapping just the water-water edge assertion block with pytest.xfail() to
preserve protein-water edge coverage while still allowing the known water-water
edge failure.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@src/dataset.py`:
- Around line 993-1000: The insertion code normalization is missing before
calculating residue starts, which causes misalignment with the cached ESM
embeddings. After the loop that sanitizes res_name for the sanitized_for_idx
object (which converts three-letter codes to one-letter and back), add code to
normalize the ins_code field by setting blank or non-standard insertion codes to
a consistent placeholder value (similar to how "X" is used for unknown
residues). This normalization must occur before calling
bts.get_residue_starts(sanitized_for_idx) to ensure the residue count and
protein_res_idx indices match what was computed in generate_esm_embeddings.py.

---

Outside diff comments:
In `@tests/test_flow.py`:
- Around line 582-617: The xfail marker is currently applied to the entire
test_batched_waters_have_edges function, which hides failures in the
protein-water edge assertions that should not be expected to fail. Remove the
xfail decorator from the function and instead apply it only to the water-water
edge checking section (the assertions checking water_nodes_with_ww_edges). This
can be done by either splitting the test into two separate test functions with
xfail only on the water-water test, or by wrapping just the water-water edge
assertion block with pytest.xfail() to preserve protein-water edge coverage
while still allowing the known water-water edge failure.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 65109962-a7f2-481b-b0ef-737262e6f23a

📥 Commits

Reviewing files that changed from the base of the PR and between c3b9db6 and a001f50.

⛔ Files ignored due to path filters (1)
  • uv.lock is excluded by !**/*.lock
📒 Files selected for processing (3)
  • src/dataset.py
  • src/flow.py
  • tests/test_flow.py

Comment thread src/dataset.py Outdated

Copilot AI left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Copilot reviewed 3 out of 4 changed files in this pull request and generated 3 comments.

Comment thread tests/test_flow.py Outdated
Comment thread tests/test_flow.py Outdated
Comment thread src/dataset.py Outdated
Copilot AI review requested due to automatic review settings June 24, 2026 22:13

Copilot AI left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Copilot reviewed 6 out of 7 changed files in this pull request and generated 1 comment.

Comment thread tests/test_flow.py
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants