Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
351 changes: 351 additions & 0 deletions src/sampleworks/core/rewards/structure_factor.py

Large diffs are not rendered by default.

144 changes: 68 additions & 76 deletions src/sampleworks/eval/generate_synthetic_sf.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
"""Generate synthetic structure factor amplitudes via SFcalculator-torch.

Produces an MTZ file of |Fmodel| (or |Fprotein| if not simulate solvent and
scale) for each input PDB/mmCIF structure. The MTZ file has dummy values for
SIGFP and optionally R-free flag column. Each structure can be optionally
overridden with unit cell, space group, atom selection, and occupancy.
For each input PDB/mmCIF structure, produces an MTZ file of protein structure factors
only, or when ``--simulate-solvent-and-scale`` is set, both the protein and total sets
(Fprotein/SIGFprotein/PHIFprotein and Ftotal/SIGFtotal/PHIFtotal) in the same
MTZ. The MTZ file has dummy values for the SIGF column(s) and optionally an R-free flag
column. Each structure can be optionally overridden with unit cell, space group, atom
selection, and occupancy.
"""

import argparse
Expand All @@ -15,20 +17,18 @@
from typing import Any, ClassVar

import gemmi
import numpy as np
import reciprocalspaceship as rs
import reciprocalspaceship.utils
import torch
from biotite.structure import AtomArray
from loguru import logger
from sampleworks.eval.synthetic_utils import (
atomarray_to_gemmi,
load_structure_for_synthetic_reward,
validate_occupancy_values,
)
from sampleworks.utils.atom_array_utils import BLANK_ALTLOC_IDS
from sampleworks.utils.torch_utils import try_gpu
from SFC_Torch import SFcalculator
from SFC_Torch.io import array2hier, PDBParser
from SFC_Torch.io import PDBParser


@dataclass
Expand Down Expand Up @@ -127,62 +127,41 @@ def from_dict(cls, row: dict[str, Any]) -> "BatchRowForMTZ":
)


def atomarray_to_gemmi(
atom_array: AtomArray,
unit_cell: gemmi.UnitCell | None = None,
space_group: str | None = None,
) -> gemmi.Structure:
"""Convert a biotite AtomArray to a gemmi.Structure for SFcalculator.

Anisotropic B-factors are set to zero since biotite does not store them.
Blank altloc labels are converted from biotite's '' to gemmi's '\\x00'.

Parameters
----------
atom_array
Input structure with occupancy and b_factor annotations
unit_cell
Crystallographic unit cell for the structure. If None, gemmi defaults
to (1.0, 1.0, 1.0, 90.0, 90.0, 90.0) in units of Angstroms and degrees.
space_group
Space group (in Hermann-Mauguin string format) for the structure. If
empty or invalid, SFcalculator defaults to P1.

Returns
-------
gemmi.Structure
Structure ready to be wrapped by SFC_Torch.io.PDBParser
def _amplitude_phase_columns(
sfc: SFcalculator,
label: str,
structure_factor_column: str,
miller_index_column: str,
sigma_f_scale: float,
) -> rs.DataSet:
"""Build a one-amplitude rs.DataSet with labelled F / SIGF / PHIF columns.

``sfc.prepare_dataset`` returns an amplitude column and a phase column (degrees)
for the given ``structure_factor_column`` attribute. We auto-detect those by MTZ
dtype (rather than assuming the unexposed ``FMODEL`` / ``PHIFMODEL`` names),
rename them to ``F{label}`` / ``PHIF{label}``, and synthesize a ``SIGF{label}``
column so several structure-factor sets (e.g. protein and total) can coexist in
one MTZ.
"""
n = len(atom_array)
cra_names = [
f"{atom_array.chain_id[i]}-0-{atom_array.res_name[i]}-{atom_array.atom_name[i]}"
for i in range(n)
]
# gemmi uses '\x00' for blank altloc
atom_altloc = ["\x00" if a in BLANK_ALTLOC_IDS else a for a in atom_array.altloc_id]
structure: gemmi.Structure = array2hier(
atom_pos=atom_array.coord,
atom_b_aniso=np.zeros((n, 3, 3), dtype=np.float64),
atom_b_iso=atom_array.b_factor,
atom_occ=atom_array.occupancy,
atom_name=atom_array.element,
cra_name=cra_names,
atom_altloc=atom_altloc,
res_id=atom_array.res_id,
dataset: rs.DataSet = sfc.prepare_dataset(miller_index_column, structure_factor_column)
amplitude_column = dataset.select_mtzdtype(rs.StructureFactorAmplitudeDtype()).columns[0]
phase_column = dataset.select_mtzdtype(rs.PhaseDtype()).columns[0]
logger.debug(
f"Auto-detected amplitude column: {amplitude_column}, "
f"phase column: {phase_column} for {label}"
)
if unit_cell is not None:
structure.cell = unit_cell
if space_group is not None:
structure.spacegroup_hm = space_group
return structure
f_col, phi_col, sig_col = f"F{label}", f"PHIF{label}", f"SIGF{label}"
dataset = dataset.rename(columns={amplitude_column: f_col, phase_column: phi_col})
dataset[sig_col] = (dataset[f_col] * sigma_f_scale).astype(rs.StandardDeviationDtype())
return dataset[[f_col, sig_col, phi_col]]


def process_amplitudes_to_dataset(
sfc: SFcalculator,
structure_factor_columns: dict[str, str],
test_fraction: float = 0.05,
seed: int | None = None,
miller_index_column: str = "Hasu_array",
structure_factor_column: str = "Ftotal_asu",
ccp4_convention: bool = False,
sigma_f_scale: float = 0.2,
output_path: Path | None = None,
Expand All @@ -193,14 +172,20 @@ def process_amplitudes_to_dataset(
----------
sfc: SFcalculator
SFcalculator instance
structure_factor_columns: dict[str, str]
Mapping of ``label -> SFcalculator attribute``. One structure-factor set
(``F{label}``/``SIGF{label}``/``PHIF{label}``) is emitted per entry, and
multiple entries are merged into one MTZ sharing the same HKL list
(``miller_index_column``) and a single R-free column, e.g.
``{"protein": "Fprotein_asu", "total": "Ftotal_asu"}`` produces
``Fprotein``/``SIGFprotein``/``PHIFprotein`` and
``Ftotal``/``SIGFtotal``/``PHIFtotal``.
test_fraction: float
Fraction of reflections to mark as R-free test set (0 disables)
seed: int | None
Optional seed for reproducible R-free flag assignment
miller_index_column: str
Attribute name in SFcalculator for hkl indices
structure_factor_column: str
Attribute name in SFcalculator for structure factors
ccp4_convention: bool
If True, use CCP4 convention for R-free flag assignment. Default
is False, which uses Phenix convention (1 = test, 0 = working).
Expand All @@ -214,18 +199,18 @@ def process_amplitudes_to_dataset(
Returns
-------
rs.DataSet
Dataset with structure factor amplitudes, fake sigma column, and optionally
R-free flags.
Dataset with structure factor amplitudes, dummy sigma column(s), phases,
and optionally R-free flags.
"""
dataset: rs.DataSet = sfc.prepare_dataset(miller_index_column, structure_factor_column)
# assumes the first detected column of dtype F is the structure factor amplitude column
# avoids hardcoding unexposed column name "FMODEL" from sfc.prepare_dataset().
structure_factor_amplitude_column = dataset.select_mtzdtype(
rs.StructureFactorAmplitudeDtype()
).columns[0]
sigma_f_column = f"SIG{structure_factor_amplitude_column}"
dataset[sigma_f_column] = dataset[structure_factor_amplitude_column] * sigma_f_scale
dataset[sigma_f_column] = dataset[sigma_f_column].astype(rs.StandardDeviationDtype())
if not structure_factor_columns:
raise ValueError("structure_factor_columns must contain at least one entry.")
column_items = iter(structure_factor_columns.items())
label, attribute = next(column_items)
dataset = _amplitude_phase_columns(sfc, label, attribute, miller_index_column, sigma_f_scale)
for label, attribute in column_items:
ds = _amplitude_phase_columns(sfc, label, attribute, miller_index_column, sigma_f_scale)
for col in ds.columns:
dataset[col] = ds[col]
if test_fraction > 0:
dataset = rs.utils.add_rfree(
dataset,
Expand Down Expand Up @@ -288,8 +273,9 @@ def _process_single_row(
If True, remove ligand molecules (non-water heteroatoms) before computing structure
factors. Default is False.
simulate_solvent_and_scale
If True, compute bulk solvent and scale factors for Ftotal instead of Fprotein.
Default is False.
If True, compute bulk solvent and scale factors and write a single MTZ containing
both the protein and total structure factor sets. If False (default), only the
protein set is written. One set contains F{label}/SIGF{label}/PHIF{label}.
save_structure
If True, save the processed structure (after selection and occupancy assignment)
as mmCIF to output_dir. Unit cell and space group are preserved. Default is False.
Expand Down Expand Up @@ -354,15 +340,17 @@ def _process_single_row(
f"n_atoms: {len(sfc.atom_pos_orth)}"
)
sfc.calc_fprotein()
structure_factor_columns = {"protein": "Fprotein_asu"}
if simulate_solvent_and_scale:
sfc.inspect_data()
sfc.calc_fsolvent()
sfc.init_scales(requires_grad=False)
sfc.calc_ftotal()
F_attribute = "Ftotal_asu"
else:
F_attribute = "Fprotein_asu"
logger.debug(f"Computed {F_attribute} for {row.filename} on {device}")
structure_factor_columns.update({"total": "Ftotal_asu"})
logger.debug(
f"Computed {'Fprotein + Ftotal' if simulate_solvent_and_scale else 'Fprotein'} "
f"for {row.filename} on {device}"
)
except Exception as e:
logger.error(
f"Failed to compute for {row.filename} ({type(e).__name__}): {e}\n"
Expand All @@ -375,7 +363,7 @@ def _process_single_row(
try:
process_amplitudes_to_dataset(
sfc,
structure_factor_column=F_attribute,
structure_factor_columns=structure_factor_columns,
test_fraction=test_fraction,
seed=seed,
output_path=output_path,
Expand Down Expand Up @@ -552,7 +540,11 @@ def parse_args() -> argparse.Namespace:
sf_group.add_argument(
"--simulate-solvent-and-scale",
action="store_true",
help="Compute bulk solvent and overall scale factors (outputs Ftotal instead of Fprotein)",
help=(
"Compute bulk solvent and overall scale factors and write both protein and "
"total structure factor in one MTZ. Without this flag, protein only. Each "
"set contains F\\{label\\}/SIGF\\{label\\}/PHIF\\{label\\}."
),
)
sf_group.add_argument(
"--remove-hydrogens",
Expand Down
78 changes: 78 additions & 0 deletions src/sampleworks/eval/synthetic_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,15 @@
import traceback
from pathlib import Path

import gemmi
import numpy as np
from atomworks.io.transforms.atom_array import remove_waters
from biotite.structure import AtomArray
from loguru import logger
from sampleworks.eval.structure_utils import apply_selection
from sampleworks.utils.atom_array_utils import (
AltlocInfo,
BLANK_ALTLOC_IDS,
detect_altlocs,
keep_amino_acids,
keep_polymer,
Expand Down Expand Up @@ -194,3 +197,78 @@ def load_structure_for_synthetic_reward(
raise ValueError(f"Invalid occupancy mode '{occupancy_mode}'")

return atom_array


def atomarray_to_gemmi(
    atom_array: AtomArray,
    unit_cell: gemmi.UnitCell | None = None,
    space_group: str | None = None,
) -> gemmi.Structure:
    """Build a gemmi.Structure from a biotite AtomArray for use with SFcalculator.

    Anisotropic B-factors are filled with zeros because biotite does not store
    them. Blank altloc labels (any member of ``BLANK_ALTLOC_IDS``) are mapped to
    gemmi's blank marker '\\x00'. When the atom array carries no ``altloc_id``
    annotation (e.g. arrays reconstructed by a model wrapper), every atom gets
    a blank altloc.

    Parameters
    ----------
    atom_array
        Input structure with occupancy and b_factor annotations
    unit_cell
        Crystallographic unit cell to set on the structure. If None, the cell is
        left untouched, i.e. gemmi's default of (1.0, 1.0, 1.0, 90.0, 90.0, 90.0)
        in units of Angstroms and degrees.
    space_group
        Space group (in Hermann-Mauguin string format) to set on the structure.
        If None, it is left untouched; if empty or invalid, SFcalculator
        defaults to P1.

    Returns
    -------
    gemmi.Structure
        Structure ready to be wrapped by SFC_Torch.io.PDBParser
    """
    # Deferred import: keeps this module importable on code paths that never
    # need SFC_Torch (e.g. synthetic density generation).
    from SFC_Torch.io import array2hier

    n_atoms = len(atom_array)

    # One "chain-0-resname-atomname" label per atom; the residue-number slot is
    # a constant "0" here, the actual residue ids are passed through ``res_id``.
    cra_labels = [
        f"{chain_id}-0-{res_name}-{atom_name}"
        for chain_id, res_name, atom_name in zip(
            atom_array.chain_id, atom_array.res_name, atom_array.atom_name
        )
    ]

    # altloc_id is not a mandatory biotite annotation, so fall back to all-blank
    # when it is missing. gemmi represents a blank altloc as '\x00'.
    if "altloc_id" not in atom_array.get_annotation_categories():
        altloc_labels = ["\x00"] * n_atoms
    else:
        altloc_labels = [
            "\x00" if altloc in BLANK_ALTLOC_IDS else altloc
            for altloc in atom_array.altloc_id
        ]

    zero_aniso = np.zeros((n_atoms, 3, 3), dtype=np.float64)
    structure: gemmi.Structure = array2hier(
        atom_pos=atom_array.coord,
        atom_b_aniso=zero_aniso,
        atom_b_iso=atom_array.b_factor,
        atom_occ=atom_array.occupancy,
        atom_name=atom_array.element,
        cra_name=cra_labels,
        atom_altloc=altloc_labels,
        res_id=atom_array.res_id,
    )

    # array2hier names the single model "SFC" and its setup_entities() assigns
    # auto-generated subchain ids (label_asym_id, e.g. "Axp"). Both corrupt a
    # written-out mmCIF: the non-integer model name breaks parsers that read
    # pdbx_PDB_model_num as int (biotite/atomworks), and the multi-character
    # label_asym_id is re-read as the chain id, which SFcalculator's PDB-header
    # step then rejects (it needs a <=1-char chain). Give every model a numeric
    # 1-based name and set each residue's subchain to its chain name so saved
    # structures (generate_synthetic_sf --save-structure) round-trip.
    for model_number, model in enumerate(structure, start=1):
        model.name = str(model_number)
        for chain in model:
            chain_name = chain.name
            for residue in chain:
                residue.subchain = chain_name

    # Optional crystallographic overrides.
    if unit_cell is not None:
        structure.cell = unit_cell
    if space_group is not None:
        structure.spacegroup_hm = space_group
    return structure
Loading
Loading