Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 40 additions & 26 deletions src/sampleworks/eval/generate_synthetic_density.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,12 @@
from typing import ClassVar

import torch
from joblib import delayed, Parallel
from loguru import logger
from sampleworks.core.forward_models.xray.real_space_density import XMap_torch
from sampleworks.eval.synthetic_utils import load_structure_for_synthetic_reward
from sampleworks.eval.synthetic_utils import (
load_structure_for_synthetic_reward,
validate_occupancy_values,
)
from sampleworks.utils.atom_array_utils import save_structure_to_cif
from sampleworks.utils.density_utils import compute_density_from_atomarray
from sampleworks.utils.torch_utils import try_gpu
Expand All @@ -26,29 +28,34 @@ class BatchRow:
Path to the structure file (relative to base_dir)
selection
Optional atom selection string in pyMOL-like syntax (e.g., 'chain A and resi 10-50')
occ_values
occupancy_values
Custom occupancy values for altlocs, must be in range [0.0, 1.0]
mapfile
Optional custom output filename for the density map
"""

VALID_EXTENSIONS: ClassVar[frozenset[str]] = frozenset({".pdb", ".cif", ".mmcif", ".ent"})
VALID_EXTENSIONS: ClassVar[frozenset[str]] = frozenset({".cif", ".mmcif"})
LEGACY_EXTENSIONS: ClassVar[frozenset[str]] = frozenset({".pdb", ".ent"})

filename: str
selection: str | None = None
occ_values: list[float] = field(default_factory=list)
occupancy_values: list[float] = field(default_factory=list)
mapfile: str | None = None

def __post_init__(self) -> None:
ext = Path(self.filename).suffix.lower()
if ext not in self.VALID_EXTENSIONS:
all_supported = self.VALID_EXTENSIONS | self.LEGACY_EXTENSIONS
if ext not in all_supported:
raise ValueError(
f"Invalid file extension '{ext}' for '{self.filename}'. "
f"Expected one of: {', '.join(sorted(self.VALID_EXTENSIONS))}"
f"Expected one of: {', '.join(sorted(all_supported))}"
)
for i, v in enumerate(self.occ_values):
if not 0.0 <= v <= 1.0:
raise ValueError(f"Occupancy value {v} at index {i} is out of range [0.0, 1.0]")
if ext in self.LEGACY_EXTENSIONS:
logger.warning(
f"'{ext}' is a legacy PDB format and support may be removed in a future version. "
"Prefer .cif or .mmcif (mmCIF format)."
)
validate_occupancy_values(self.occupancy_values)

@classmethod
def from_dict(cls, row: dict[str, str]) -> "BatchRow":
Expand All @@ -58,7 +65,7 @@ def from_dict(cls, row: dict[str, str]) -> "BatchRow":
----------
row
Dictionary with keys 'filename' (required), and optionally
'selection', 'occ_values' (colon-separated), and 'mapfile'
'selection', 'occupancy_values' (colon-separated), and 'mapfile'

Returns
-------
Expand All @@ -75,14 +82,15 @@ def from_dict(cls, row: dict[str, str]) -> "BatchRow":
if "filename" not in row:
raise KeyError("CSV is missing required 'filename' column")

occ_values: list[float] = []
if row.get("occ_values"):
occ_values = [float(v.strip()) for v in row["occ_values"].split(":")]
occupancy_values: list[float] = []
occupancy_values_csv = row.get("occupancy_values") or row.get("occ_values")
if occupancy_values_csv:
occupancy_values = [float(v.strip()) for v in occupancy_values_csv.split(":")]

return cls(
filename=row["filename"],
selection=row.get("selection") or None,
occ_values=occ_values,
occupancy_values=occupancy_values,
mapfile=row.get("mapfile") or None,
)

Expand Down Expand Up @@ -111,7 +119,7 @@ def load_batch_csv(csv_path: Path) -> list[BatchRow]:
----------
csv_path
Path to CSV file with columns: filename (required), selection (optional),
occ_values (optional), mapfile (optional)
occupancy_values (optional), mapfile (optional)

Returns
-------
Expand All @@ -135,7 +143,7 @@ def load_batch_csv(csv_path: Path) -> list[BatchRow]:

def _process_single_row(
row: BatchRow,
occ_mode: str,
occupancy_mode: str,
base_dir: Path,
output_dir: Path,
resolution: float,
Expand All @@ -152,7 +160,7 @@ def _process_single_row(
----------
row
BatchRow containing structure information
occ_mode
occupancy_mode
Occupancy assignment mode: 'default', 'uniform', or 'custom'
base_dir
Base directory for resolving relative structure file paths
Expand Down Expand Up @@ -180,8 +188,8 @@ def _process_single_row(
structure_path = base_dir / row.filename
atom_array = load_structure_for_synthetic_reward(
structure_path,
occupancy_mode=occ_mode,
occupancy_values=row.occ_values,
occupancy_mode=occupancy_mode,
occupancy_values=row.occupancy_values,
strip_hydrogens=strip_hydrogens,
strip_waters=strip_waters,
strip_ligands=strip_ligands,
Expand Down Expand Up @@ -237,7 +245,7 @@ def process_batch(
base_dir: Path,
output_dir: Path,
resolution: float,
occ_mode: str,
occupancy_mode: str,
em_mode: bool,
device: torch.device,
n_jobs: int = -1,
Expand Down Expand Up @@ -273,6 +281,8 @@ def process_batch(
save_structure
If True, save the processed structure to a CIF file in the input directory.
"""
from joblib import delayed, Parallel

rows = load_batch_csv(csv_path)
logger.info(f"Processing {len(rows)} structures from {csv_path} using {n_jobs} jobs")

Expand All @@ -282,7 +292,7 @@ def process_batch(
Parallel(n_jobs=n_jobs, backend="loky")(
delayed(_process_single_row)(
row=row,
occ_mode=occ_mode,
occupancy_mode=occupancy_mode,
base_dir=base_dir,
output_dir=output_dir,
resolution=resolution,
Expand Down Expand Up @@ -323,13 +333,17 @@ def parse_args() -> argparse.Namespace:

occ_group = parser.add_argument_group("Occupancy Options")
occ_group.add_argument(
"--occupancy-mode",
"--occ-mode",
dest="occupancy_mode",
choices=["default", "uniform", "custom"],
default="default",
help="Occupancy assignment mode",
)
occ_group.add_argument(
"--occupancy-values",
"--occ-values",
dest="occupancy_values",
type=str,
help="Colon-separated occupancy values for custom mode (e.g., '0.3:0.7')",
)
Expand Down Expand Up @@ -389,7 +403,7 @@ def main() -> None:
base_dir=args.base_dir,
output_dir=args.output_dir,
resolution=args.resolution,
occ_mode=args.occ_mode,
occupancy_mode=args.occupancy_mode,
em_mode=args.em_mode,
device=device,
n_jobs=args.n_jobs,
Expand All @@ -402,14 +416,14 @@ def main() -> None:
row = BatchRow(
filename=str(args.structure),
selection=args.selection,
occ_values=[float(v.strip()) for v in args.occ_values.split(":")]
if args.occ_values
occupancy_values=[float(v.strip()) for v in args.occupancy_values.split(":")]
if args.occupancy_values
else [],
mapfile=args.output.name if args.output else None,
)
_process_single_row(
row=row,
occ_mode=args.occ_mode,
occupancy_mode=args.occupancy_mode,
base_dir=args.structure.parent,
output_dir=args.output.parent if args.output else Path("."),
resolution=args.resolution,
Expand Down
33 changes: 33 additions & 0 deletions tests/eval/test_generate_synthetic_density.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
"""Tests for synthetic density batch-row argument handling."""

import pytest
from sampleworks.eval.generate_synthetic_density import BatchRow


def test_batch_row_accepts_occupancy_values_column() -> None:
"""Batch CSV parsing uses the canonical occupancy_values column name."""
row = BatchRow.from_dict(
{"filename": "input.cif", "selection": "chain A", "occupancy_values": "0.25:0.75"}
)

assert row.occupancy_values == [0.25, 0.75]
assert row.selection == "chain A"


def test_batch_row_accepts_legacy_occ_values_column() -> None:
"""The old occ_values column remains accepted for existing batch CSVs."""
row = BatchRow.from_dict({"filename": "input.cif", "occ_values": "0.4:0.6"})

assert row.occupancy_values == [0.4, 0.6]


def test_batch_row_rejects_occupancy_values_that_do_not_sum_to_one() -> None:
"""Density generation now uses the shared occupancy-value validation helper."""
with pytest.raises(ValueError, match="must sum to 1.0"):
BatchRow(filename="input.cif", occupancy_values=[0.2, 0.3])


def test_batch_row_rejects_unsupported_extension() -> None:
"""Only mmCIF and legacy PDB-like structure extensions are supported."""
with pytest.raises(ValueError, match="Invalid file extension"):
BatchRow(filename="input.txt")
Loading