|
1 | 1 | from __future__ import annotations |
2 | 2 |
|
| 3 | +from typing import TYPE_CHECKING |
| 4 | + |
3 | 5 | from emmet.core.symmetry import CrystalSystem |
4 | | -from emmet.core.vasp.material import MaterialsDoc |
| 6 | +from emmet.core.vasp.calc_types import RunType |
| 7 | +from emmet.core.vasp.material import BlessedCalcs, MaterialsDoc |
5 | 8 | from pymatgen.core.structure import Structure |
6 | 9 |
|
7 | 10 | from mp_api.client.core.client import CoreRester, MPRestError |
8 | 11 | from mp_api.client.core.settings import MAPI_CLIENT_SETTINGS |
9 | 12 | from mp_api.client.core.utils import validate_ids |
10 | 13 | from mp_api.client.routes.materials import MATERIALS_RESTERS |
11 | 14 |
|
| 15 | +if TYPE_CHECKING: |
| 16 | + from typing import Any |
| 17 | + |
| 18 | + from pymatgen.entries.computed_entries import ComputedStructureEntry |
| 19 | + |
12 | 20 |
|
13 | 21 | class MaterialsRester(CoreRester): |
14 | 22 | suffix = "materials/core" |
@@ -226,3 +234,75 @@ def find_structure( |
226 | 234 | return material_ids # type: ignore |
227 | 235 |
|
228 | 236 | return material_ids[0] |
| 237 | + |
| 238 | + def get_blessed_entries( |
| 239 | + self, |
| 240 | + run_type: str | RunType = RunType.r2SCAN, |
| 241 | + material_ids: list[str] | None = None, |
| 242 | + uncorrected_energy: tuple[float | None, float | None] | float | None = None, |
| 243 | + num_chunks: int | None = None, |
| 244 | + chunk_size: int = 1000, |
| 245 | + ) -> list[dict[str, str | dict | ComputedStructureEntry]]: |
| 246 | + """Get blessed calculation entries for a given material and run type. |
| 247 | +
|
| 248 | + Args: |
| 249 | + run_type (str or RunType): Calculation run type (e.g. GGA, GGA+U, r2SCAN, PBESol) |
| 250 | + material_ids (list[str]): List of material ID values |
| 251 | + uncorrected_energy (tuple[Optional[float], Optional[float]] | float): Tuple of minimum and maximum uncorrected DFT energy in eV/atom. |
| 252 | + Note that if a single value is passed, it will be used as the minimum and maximum. |
| 253 | + num_chunks (int): Maximum number of chunks of data to yield. None will yield all possible. |
| 254 | + chunk_size (int): Number of data entries per chunk. |
| 255 | +
|
| 256 | + Returns: |
| 257 | + list of dict, of the form: |
| 258 | + { |
| 259 | + "material_id": MPID, |
| 260 | + "blessed_entry": ComputedStructureEntry |
| 261 | + } |
| 262 | + """ |
| 263 | + query_params: dict[str, Any] = {"run_type": str(run_type)} |
| 264 | + if material_ids: |
| 265 | + if isinstance(material_ids, str): |
| 266 | + material_ids = [material_ids] |
| 267 | + |
| 268 | + query_params.update({"material_ids": ",".join(validate_ids(material_ids))}) |
| 269 | + |
| 270 | + if uncorrected_energy: |
| 271 | + if isinstance(uncorrected_energy, float): |
| 272 | + uncorrected_energy = (uncorrected_energy, uncorrected_energy) |
| 273 | + |
| 274 | + query_params.update( |
| 275 | + { |
| 276 | + "energy_min": uncorrected_energy[0], |
| 277 | + "energy_max": uncorrected_energy[1], |
| 278 | + } |
| 279 | + ) |
| 280 | + |
| 281 | + results = self._query_resource( |
| 282 | + query_params, |
| 283 | + fields=["material_id", "entries"], |
| 284 | + suburl="blessed_tasks", |
| 285 | + parallel_param="material_ids" if material_ids else None, |
| 286 | + chunk_size=chunk_size, |
| 287 | + num_chunks=num_chunks, |
| 288 | + ) |
| 289 | + |
| 290 | + return [ |
| 291 | + { |
| 292 | + "material_id": doc["material_id"], |
| 293 | + "blessed_entry": ( |
| 294 | + next( |
| 295 | + getattr(doc["entries"], k, None) |
| 296 | + for k in BlessedCalcs.model_fields |
| 297 | + if getattr(doc["entries"], k, None) |
| 298 | + ) |
| 299 | + if self.use_document_model |
| 300 | + else next( |
| 301 | + doc["entries"][k] |
| 302 | + for k in BlessedCalcs.model_fields |
| 303 | + if doc["entries"].get(k) |
| 304 | + ) |
| 305 | + ), |
| 306 | + } |
| 307 | + for doc in (results.get("data") or []) |
| 308 | + ] |
0 commit comments