Skip to content

Commit 9a20edb

Browse files
merge conflicts
2 parents 69799bf + e1cd425 commit 9a20edb

2 files changed

Lines changed: 102 additions & 1 deletion

File tree

mp_api/client/routes/materials/materials.py

Lines changed: 81 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,22 @@
11
from __future__ import annotations
22

3+
from typing import TYPE_CHECKING
4+
35
from emmet.core.symmetry import CrystalSystem
4-
from emmet.core.vasp.material import MaterialsDoc
6+
from emmet.core.vasp.calc_types import RunType
7+
from emmet.core.vasp.material import BlessedCalcs, MaterialsDoc
58
from pymatgen.core.structure import Structure
69

710
from mp_api.client.core.client import CoreRester, MPRestError
811
from mp_api.client.core.settings import MAPI_CLIENT_SETTINGS
912
from mp_api.client.core.utils import validate_ids
1013
from mp_api.client.routes.materials import MATERIALS_RESTERS
1114

15+
if TYPE_CHECKING:
16+
from typing import Any
17+
18+
from pymatgen.entries.computed_entries import ComputedStructureEntry
19+
1220

1321
class MaterialsRester(CoreRester):
1422
suffix = "materials/core"
@@ -226,3 +234,75 @@ def find_structure(
226234
return material_ids # type: ignore
227235

228236
return material_ids[0]
237+
238+
def get_blessed_entries(
239+
self,
240+
run_type: str | RunType = RunType.r2SCAN,
241+
material_ids: list[str] | None = None,
242+
uncorrected_energy: tuple[float | None, float | None] | float | None = None,
243+
num_chunks: int | None = None,
244+
chunk_size: int = 1000,
245+
) -> list[dict[str, str | dict | ComputedStructureEntry]]:
246+
"""Get blessed calculation entries for a given material and run type.
247+
248+
Args:
249+
run_type (str or RunType): Calculation run type (e.g. GGA, GGA+U, r2SCAN, PBESol)
250+
material_ids (list[str]): List of material ID values
251+
uncorrected_energy (tuple[Optional[float], Optional[float]] | float): Tuple of minimum and maximum uncorrected DFT energy in eV/atom.
252+
Note that if a single value is passed, it will be used as the minimum and maximum.
253+
num_chunks (int): Maximum number of chunks of data to yield. None will yield all possible.
254+
chunk_size (int): Number of data entries per chunk.
255+
256+
Returns:
257+
list of dict, of the form:
258+
{
259+
"material_id": MPID,
260+
"blessed_entry": ComputedStructureEntry
261+
}
262+
"""
263+
query_params: dict[str, Any] = {"run_type": str(run_type)}
264+
if material_ids:
265+
if isinstance(material_ids, str):
266+
material_ids = [material_ids]
267+
268+
query_params.update({"material_ids": ",".join(validate_ids(material_ids))})
269+
270+
if uncorrected_energy:
271+
if isinstance(uncorrected_energy, float):
272+
uncorrected_energy = (uncorrected_energy, uncorrected_energy)
273+
274+
query_params.update(
275+
{
276+
"energy_min": uncorrected_energy[0],
277+
"energy_max": uncorrected_energy[1],
278+
}
279+
)
280+
281+
results = self._query_resource(
282+
query_params,
283+
fields=["material_id", "entries"],
284+
suburl="blessed_tasks",
285+
parallel_param="material_ids" if material_ids else None,
286+
chunk_size=chunk_size,
287+
num_chunks=num_chunks,
288+
)
289+
290+
return [
291+
{
292+
"material_id": doc["material_id"],
293+
"blessed_entry": (
294+
next(
295+
getattr(doc["entries"], k, None)
296+
for k in BlessedCalcs.model_fields
297+
if getattr(doc["entries"], k, None)
298+
)
299+
if self.use_document_model
300+
else next(
301+
doc["entries"][k]
302+
for k in BlessedCalcs.model_fields
303+
if doc["entries"].get(k)
304+
)
305+
),
306+
}
307+
for doc in (results.get("data") or [])
308+
]

tests/client/materials/test_materials.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,3 +67,24 @@ def test_client(rester):
6767
custom_field_tests=custom_field_tests,
6868
sub_doc_fields=sub_doc_fields,
6969
)
70+
71+
72+
@pytest.mark.xfail(condition=True, reason="Needs new deployment.", strict=False)
73+
@pytest.mark.parametrize(
74+
"run_type, uncorrected_energy, use_document_model",
75+
[("PBE", None, True), ("r2SCAN", 1.0, False), ("GGA_U", (-50e4, 0.0), True)],
76+
)
77+
def test_blessed_entry(run_type, uncorrected_energy, use_document_model):
78+
# Si and NiO. Si has GGA and r2SCAN entries, NiO has GGA, GGA+U, and r2SCAN
79+
with MaterialsRester(use_document_model=use_document_model) as rester:
80+
blessed = rester.get_blessed_entries(
81+
run_type,
82+
material_ids=["mp-149", "mp-19009"],
83+
uncorrected_energy=uncorrected_energy,
84+
)
85+
86+
assert all(
87+
isinstance(entry, dict)
88+
and all(entry.get(k) is not None for k in ("material_id", "blessed_entry"))
89+
for entry in blessed
90+
)

0 commit comments

Comments
 (0)