11from __future__ import annotations
22
3+ from typing import TYPE_CHECKING
4+
35from emmet .core .settings import EmmetSettings
46from emmet .core .symmetry import CrystalSystem
57from emmet .core .vasp .calc_types import RunType
6- from emmet .core .vasp .material import MaterialsDoc
8+ from emmet .core .vasp .material import MaterialsDoc , BlessedCalcs
79from pymatgen .core .structure import Structure
810
911from mp_api .client .core import BaseRester , MPRestError
3840 XASRester ,
3941)
4042
43+ if TYPE_CHECKING :
44+ from typing import Any
45+ from pymatgen .entries .computed_entries import ComputedStructureEntry
46+
4147_EMMET_SETTINGS = EmmetSettings () # type: ignore
4248
4349
@@ -322,23 +328,30 @@ def find_structure(
322328
323329 def get_blessed_entries (
324330 self ,
325- run_type : RunType = RunType .R2SCAN ,
331+ run_type : str | RunType = RunType .r2SCAN ,
326332 material_ids : list [str ] | None = None ,
327333 uncorrected_energy : tuple [float | None , float | None ] | float | None = None ,
328334 num_chunks : int | None = None ,
329335 chunk_size : int = 1000 ,
330- ):
336+ ) -> list [ dict [ str , str | dict | ComputedStructureEntry ]] :
331337 """Get blessed calculation entries for a given material and run type.
332338
333339 Args:
334- run_type (RunType): Calculation run type (e.g. GGA, GGA+U, R2SCAN , PBESol)
340+ run_type (str or RunType): Calculation run type (e.g. GGA, GGA+U, r2SCAN , PBESol)
335341 material_ids (list[str]): List of material ID values
336342 uncorrected_energy (tuple[Optional[float], Optional[float]] | float): Tuple of minimum and maximum uncorrected DFT energy in eV/atom.
337343 Note that if a single value is passed, it will be used as the minimum and maximum.
338344 num_chunks (int): Maximum number of chunks of data to yield. None will yield all possible.
339345 chunk_size (int): Number of data entries per chunk.
346+
347+ Returns:
348+ list of dict, of the form:
349+ {
350+ "material_id": MPID,
351+ "blessed_entry": ComputedStructureEntry
352+ }
340353 """
341- query_params : dict = {"run_type" : str (run_type )}
354+ query_params : dict [ str , Any ] = {"run_type" : str (run_type )}
342355 if material_ids :
343356 if isinstance (material_ids , str ):
344357 material_ids = [material_ids ]
@@ -351,17 +364,36 @@ def get_blessed_entries(
351364
352365 query_params .update (
353366 {
354- "uncorrected_energy_min" : uncorrected_energy [0 ], # type: ignore
355- "uncorrected_energy_max" : uncorrected_energy [1 ], # type: ignore
367+ "uncorrected_energy_min" : uncorrected_energy [0 ],
368+ "uncorrected_energy_max" : uncorrected_energy [1 ],
356369 }
357370 )
358371
359372 results = self ._query_resource (
360373 query_params ,
361- # fields=["material_ids ", "entries"],
374+ fields = ["material_id " , "entries" ],
362375 suburl = "blessed_tasks" ,
363376 parallel_param = "material_ids" if material_ids else None ,
364377 chunk_size = chunk_size ,
365378 num_chunks = num_chunks ,
366- )
367- return results .get ("data" )
379+ )
380+
381+ return [
382+ {
383+ "material_id" : doc ["material_id" ],
384+ "blessed_entry" : (
385+ next (
386+ getattr (doc ["entries" ],k ,None )
387+ for k in BlessedCalcs .model_fields
388+ if getattr (doc ["entries" ],k ,None )
389+ )
390+ if self .use_document_model
391+ else next (
392+ doc ["entries" ][k ]
393+ for k in BlessedCalcs .model_fields
394+ if doc ["entries" ].get (k )
395+ )
396+ ),
397+ }
398+ for doc in (results .get ("data" ) or [])
399+ ]
0 commit comments