Skip to content

Commit e1cd425

Browse files
[WIP] Add convenience method for obtaining blessed calculation for different functionals (#880)
2 parents a796979 + 9fb7e52 commit e1cd425

2 files changed

Lines changed: 102 additions & 1 deletion

File tree

mp_api/client/routes/materials/materials.py

Lines changed: 81 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
11
from __future__ import annotations
22

3+
from typing import TYPE_CHECKING
4+
35
from emmet.core.settings import EmmetSettings
46
from emmet.core.symmetry import CrystalSystem
5-
from emmet.core.vasp.material import MaterialsDoc
7+
from emmet.core.vasp.calc_types import RunType
8+
from emmet.core.vasp.material import BlessedCalcs, MaterialsDoc
69
from pymatgen.core.structure import Structure
710

811
from mp_api.client.core import BaseRester, MPRestError
@@ -37,6 +40,11 @@
3740
XASRester,
3841
)
3942

43+
if TYPE_CHECKING:
44+
from typing import Any
45+
46+
from pymatgen.entries.computed_entries import ComputedStructureEntry
47+
4048
_EMMET_SETTINGS = EmmetSettings() # type: ignore
4149

4250

@@ -318,3 +326,75 @@ def find_structure(
318326
return material_ids # type: ignore
319327

320328
return material_ids[0]
329+
330+
def get_blessed_entries(
331+
self,
332+
run_type: str | RunType = RunType.r2SCAN,
333+
material_ids: list[str] | None = None,
334+
uncorrected_energy: tuple[float | None, float | None] | float | None = None,
335+
num_chunks: int | None = None,
336+
chunk_size: int = 1000,
337+
) -> list[dict[str, str | dict | ComputedStructureEntry]]:
338+
"""Get blessed calculation entries for a given material and run type.
339+
340+
Args:
341+
run_type (str or RunType): Calculation run type (e.g. GGA, GGA+U, r2SCAN, PBESol)
342+
material_ids (list[str]): List of material ID values
343+
uncorrected_energy (tuple[Optional[float], Optional[float]] | float): Tuple of minimum and maximum uncorrected DFT energy in eV/atom.
344+
Note that if a single value is passed, it will be used as the minimum and maximum.
345+
num_chunks (int): Maximum number of chunks of data to yield. None will yield all possible.
346+
chunk_size (int): Number of data entries per chunk.
347+
348+
Returns:
349+
list of dict, of the form:
350+
{
351+
"material_id": MPID,
352+
"blessed_entry": ComputedStructureEntry
353+
}
354+
"""
355+
query_params: dict[str, Any] = {"run_type": str(run_type)}
356+
if material_ids:
357+
if isinstance(material_ids, str):
358+
material_ids = [material_ids]
359+
360+
query_params.update({"material_ids": ",".join(validate_ids(material_ids))})
361+
362+
if uncorrected_energy:
363+
if isinstance(uncorrected_energy, float):
364+
uncorrected_energy = (uncorrected_energy, uncorrected_energy)
365+
366+
query_params.update(
367+
{
368+
"energy_min": uncorrected_energy[0],
369+
"energy_max": uncorrected_energy[1],
370+
}
371+
)
372+
373+
results = self._query_resource(
374+
query_params,
375+
fields=["material_id", "entries"],
376+
suburl="blessed_tasks",
377+
parallel_param="material_ids" if material_ids else None,
378+
chunk_size=chunk_size,
379+
num_chunks=num_chunks,
380+
)
381+
382+
return [
383+
{
384+
"material_id": doc["material_id"],
385+
"blessed_entry": (
386+
next(
387+
getattr(doc["entries"], k, None)
388+
for k in BlessedCalcs.model_fields
389+
if getattr(doc["entries"], k, None)
390+
)
391+
if self.use_document_model
392+
else next(
393+
doc["entries"][k]
394+
for k in BlessedCalcs.model_fields
395+
if doc["entries"].get(k)
396+
)
397+
),
398+
}
399+
for doc in (results.get("data") or [])
400+
]

tests/client/materials/test_materials.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,3 +62,24 @@ def test_client(rester):
6262
custom_field_tests=custom_field_tests,
6363
sub_doc_fields=sub_doc_fields,
6464
)
65+
66+
67+
@pytest.mark.xfail(condition=True, reason="Needs new deployment.", strict=False)
68+
@pytest.mark.parametrize(
69+
"run_type, uncorrected_energy, use_document_model",
70+
[("PBE", None, True), ("r2SCAN", 1.0, False), ("GGA_U", (-50e4, 0.0), True)],
71+
)
72+
def test_blessed_entry(run_type, uncorrected_energy, use_document_model):
73+
# Si and NiO. Si has GGA and r2SCAN entries, NiO has GGA, GGA+U, and r2SCAN
74+
with MaterialsRester(use_document_model=use_document_model) as rester:
75+
blessed = rester.get_blessed_entries(
76+
run_type,
77+
material_ids=["mp-149", "mp-19009"],
78+
uncorrected_energy=uncorrected_energy,
79+
)
80+
81+
assert all(
82+
isinstance(entry, dict)
83+
and all(entry.get(k) is not None for k in ("material_id", "blessed_entry"))
84+
for entry in blessed
85+
)

0 commit comments

Comments
 (0)