|
1 | 1 | from __future__ import annotations |
2 | 2 |
|
| 3 | +from typing import TYPE_CHECKING |
| 4 | + |
3 | 5 | from emmet.core.settings import EmmetSettings |
4 | 6 | from emmet.core.symmetry import CrystalSystem |
5 | | -from emmet.core.vasp.material import MaterialsDoc |
| 7 | +from emmet.core.vasp.calc_types import RunType |
| 8 | +from emmet.core.vasp.material import BlessedCalcs, MaterialsDoc |
6 | 9 | from pymatgen.core.structure import Structure |
7 | 10 |
|
8 | 11 | from mp_api.client.core import BaseRester, MPRestError |
|
37 | 40 | XASRester, |
38 | 41 | ) |
39 | 42 |
|
| 43 | +if TYPE_CHECKING: |
| 44 | + from typing import Any |
| 45 | + |
| 46 | + from pymatgen.entries.computed_entries import ComputedStructureEntry |
| 47 | + |
40 | 48 | _EMMET_SETTINGS = EmmetSettings() # type: ignore |
41 | 49 |
|
42 | 50 |
|
@@ -318,3 +326,75 @@ def find_structure( |
318 | 326 | return material_ids # type: ignore |
319 | 327 |
|
320 | 328 | return material_ids[0] |
| 329 | + |
| 330 | + def get_blessed_entries( |
| 331 | + self, |
| 332 | + run_type: str | RunType = RunType.r2SCAN, |
| 333 | + material_ids: list[str] | None = None, |
| 334 | + uncorrected_energy: tuple[float | None, float | None] | float | None = None, |
| 335 | + num_chunks: int | None = None, |
| 336 | + chunk_size: int = 1000, |
| 337 | + ) -> list[dict[str, str | dict | ComputedStructureEntry]]: |
| 338 | + """Get blessed calculation entries for a given material and run type. |
| 339 | +
|
| 340 | + Args: |
| 341 | + run_type (str or RunType): Calculation run type (e.g. GGA, GGA+U, r2SCAN, PBESol) |
| 342 | + material_ids (list[str]): List of material ID values |
| 343 | + uncorrected_energy (tuple[Optional[float], Optional[float]] | float): Tuple of minimum and maximum uncorrected DFT energy in eV/atom. |
| 344 | + Note that if a single value is passed, it will be used as the minimum and maximum. |
| 345 | + num_chunks (int): Maximum number of chunks of data to yield. None will yield all possible. |
| 346 | + chunk_size (int): Number of data entries per chunk. |
| 347 | +
|
| 348 | + Returns: |
| 349 | + list of dict, of the form: |
| 350 | + { |
| 351 | + "material_id": MPID, |
| 352 | + "blessed_entry": ComputedStructureEntry |
| 353 | + } |
| 354 | + """ |
| 355 | + query_params: dict[str, Any] = {"run_type": str(run_type)} |
| 356 | + if material_ids: |
| 357 | + if isinstance(material_ids, str): |
| 358 | + material_ids = [material_ids] |
| 359 | + |
| 360 | + query_params.update({"material_ids": ",".join(validate_ids(material_ids))}) |
| 361 | + |
| 362 | + if uncorrected_energy: |
| 363 | + if isinstance(uncorrected_energy, float): |
| 364 | + uncorrected_energy = (uncorrected_energy, uncorrected_energy) |
| 365 | + |
| 366 | + query_params.update( |
| 367 | + { |
| 368 | + "energy_min": uncorrected_energy[0], |
| 369 | + "energy_max": uncorrected_energy[1], |
| 370 | + } |
| 371 | + ) |
| 372 | + |
| 373 | + results = self._query_resource( |
| 374 | + query_params, |
| 375 | + fields=["material_id", "entries"], |
| 376 | + suburl="blessed_tasks", |
| 377 | + parallel_param="material_ids" if material_ids else None, |
| 378 | + chunk_size=chunk_size, |
| 379 | + num_chunks=num_chunks, |
| 380 | + ) |
| 381 | + |
| 382 | + return [ |
| 383 | + { |
| 384 | + "material_id": doc["material_id"], |
| 385 | + "blessed_entry": ( |
| 386 | + next( |
| 387 | + getattr(doc["entries"], k, None) |
| 388 | + for k in BlessedCalcs.model_fields |
| 389 | + if getattr(doc["entries"], k, None) |
| 390 | + ) |
| 391 | + if self.use_document_model |
| 392 | + else next( |
| 393 | + doc["entries"][k] |
| 394 | + for k in BlessedCalcs.model_fields |
| 395 | + if doc["entries"].get(k) |
| 396 | + ) |
| 397 | + ), |
| 398 | + } |
| 399 | + for doc in (results.get("data") or []) |
| 400 | + ] |
0 commit comments