Skip to content

Commit b3340d3

Browse files
fix blessed entries rester
1 parent 317ffe4 commit b3340d3

1 file changed

Lines changed: 42 additions & 10 deletions

File tree

mp_api/client/routes/materials/materials.py

Lines changed: 42 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
from __future__ import annotations
22

3+
from typing import TYPE_CHECKING
4+
35
from emmet.core.settings import EmmetSettings
46
from emmet.core.symmetry import CrystalSystem
57
from emmet.core.vasp.calc_types import RunType
6-
from emmet.core.vasp.material import MaterialsDoc
8+
from emmet.core.vasp.material import MaterialsDoc, BlessedCalcs
79
from pymatgen.core.structure import Structure
810

911
from mp_api.client.core import BaseRester, MPRestError
@@ -38,6 +40,10 @@
3840
XASRester,
3941
)
4042

43+
if TYPE_CHECKING:
44+
from typing import Any
45+
from pymatgen.entries.computed_entries import ComputedStructureEntry
46+
4147
_EMMET_SETTINGS = EmmetSettings() # type: ignore
4248

4349

@@ -322,23 +328,30 @@ def find_structure(
322328

323329
def get_blessed_entries(
324330
self,
325-
run_type: RunType = RunType.R2SCAN,
331+
run_type: str | RunType = RunType.r2SCAN,
326332
material_ids: list[str] | None = None,
327333
uncorrected_energy: tuple[float | None, float | None] | float | None = None,
328334
num_chunks: int | None = None,
329335
chunk_size: int = 1000,
330-
):
336+
) -> list[dict[str, str | dict | ComputedStructureEntry]]:
331337
"""Get blessed calculation entries for a given material and run type.
332338
333339
Args:
334-
run_type (RunType): Calculation run type (e.g. GGA, GGA+U, R2SCAN, PBESol)
340+
run_type (str or RunType): Calculation run type (e.g. GGA, GGA+U, r2SCAN, PBESol)
335341
material_ids (list[str]): List of material ID values
336342
uncorrected_energy (tuple[Optional[float], Optional[float]] | float): Tuple of minimum and maximum uncorrected DFT energy in eV/atom.
337343
Note that if a single value is passed, it will be used as the minimum and maximum.
338344
num_chunks (int): Maximum number of chunks of data to yield. None will yield all possible.
339345
chunk_size (int): Number of data entries per chunk.
346+
347+
Returns:
348+
list of dict, of the form:
349+
{
350+
"material_id": MPID,
351+
"blessed_entry": ComputedStructureEntry
352+
}
340353
"""
341-
query_params: dict = {"run_type": str(run_type)}
354+
query_params: dict[str, Any] = {"run_type": str(run_type)}
342355
if material_ids:
343356
if isinstance(material_ids, str):
344357
material_ids = [material_ids]
@@ -351,17 +364,36 @@ def get_blessed_entries(
351364

352365
query_params.update(
353366
{
354-
"uncorrected_energy_min": uncorrected_energy[0], # type: ignore
355-
"uncorrected_energy_max": uncorrected_energy[1], # type: ignore
367+
"uncorrected_energy_min": uncorrected_energy[0],
368+
"uncorrected_energy_max": uncorrected_energy[1],
356369
}
357370
)
358371

359372
results = self._query_resource(
360373
query_params,
361-
# fields=["material_ids", "entries"],
374+
fields=["material_id", "entries"],
362375
suburl="blessed_tasks",
363376
parallel_param="material_ids" if material_ids else None,
364377
chunk_size=chunk_size,
365378
num_chunks=num_chunks,
366-
)
367-
return results.get("data")
379+
)
380+
381+
return [
382+
{
383+
"material_id": doc["material_id"],
384+
"blessed_entry": (
385+
next(
386+
getattr(doc["entries"],k,None)
387+
for k in BlessedCalcs.model_fields
388+
if getattr(doc["entries"],k,None)
389+
)
390+
if self.use_document_model
391+
else next(
392+
doc["entries"][k]
393+
for k in BlessedCalcs.model_fields
394+
if doc["entries"].get(k)
395+
)
396+
),
397+
}
398+
for doc in (results.get("data") or [])
399+
]

0 commit comments

Comments
 (0)