Skip to content

Commit 69e3715

Browse files
Compatibility with emmet-core 0.86.0rc1 (#1021)
1 parent 254c7d0 commit 69e3715

9 files changed

Lines changed: 107 additions & 154 deletions

File tree

mp_api/client/core/client.py

Lines changed: 21 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,7 @@
1818
from importlib.metadata import PackageNotFoundError, version
1919
from json import JSONDecodeError
2020
from math import ceil
21-
from typing import (
22-
TYPE_CHECKING,
23-
ForwardRef,
24-
Optional,
25-
get_args,
26-
)
21+
from typing import TYPE_CHECKING, ForwardRef, Optional, get_args
2722
from urllib.parse import quote, urljoin
2823

2924
import requests
@@ -64,6 +59,23 @@
6459
SETTINGS = MAPIClientSettings() # type: ignore
6560

6661

62+
class _DictLikeAccess(BaseModel):
63+
"""Define a pydantic mix-in which permits dict-like access to model fields."""
64+
65+
def __getitem__(self, item: str) -> Any:
66+
"""Return `item` if a valid model field, otherwise raise an exception."""
67+
if item in self.__class__.model_fields:
68+
return getattr(self, item)
69+
raise AttributeError(f"{self.__class__.__name__} has no model field `{item}`.")
70+
71+
def get(self, item: str, default: Any = None) -> Any:
72+
"""Return a model field `item`, or `default` if it doesn't exist."""
73+
try:
74+
return self.__getitem__(item)
75+
except AttributeError:
76+
return default
77+
78+
6779
class BaseRester:
6880
"""Base client class with core stubs."""
6981

@@ -427,13 +439,9 @@ def _query_resource(
427439
if use_document_model is None:
428440
use_document_model = self.use_document_model
429441

430-
if timeout is None:
431-
timeout = self.timeout
442+
timeout = self.timeout if timeout is None else timeout
432443

433-
if criteria:
434-
criteria = {k: v for k, v in criteria.items() if v is not None}
435-
else:
436-
criteria = {}
444+
criteria = {k: v for k, v in (criteria or {}).items() if v is not None}
437445

438446
# Query s3 if no query is passed and all documents are asked for
439447
# TODO also skip fields set to same as their default
@@ -1080,6 +1088,7 @@ def _generate_returned_model(
10801088
# TODO fields_not_requested is not the same as unset_fields
10811089
# i.e. field could be requested but not available in the raw doc
10821090
fields_not_requested=(list[str], unset_fields),
1091+
__base__=_DictLikeAccess,
10831092
__doc__=".".join(
10841093
[
10851094
getattr(self.document_model, k, "")

mp_api/client/mprester.py

Lines changed: 27 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import itertools
44
import os
55
import warnings
6+
from collections import defaultdict
67
from functools import cache, lru_cache
78
from typing import TYPE_CHECKING
89

@@ -424,20 +425,15 @@ def get_task_ids_associated_with_material_id(
424425
if not tasks:
425426
return []
426427

427-
calculations = (
428-
tasks[0].calc_types # type: ignore
429-
if self.use_document_model
430-
else tasks[0]["calc_types"] # type: ignore
431-
)
428+
calculations = tasks[0]["calc_types"]
432429

433430
if calc_types:
434431
return [
435432
task
436433
for task, calc_type in calculations.items()
437434
if calc_type in calc_types
438435
]
439-
else:
440-
return list(calculations.keys())
436+
return list(calculations.keys())
441437

442438
def get_structure_by_material_id(
443439
self, material_id: str, final: bool = True, conventional_unit_cell: bool = False
@@ -539,11 +535,7 @@ def get_material_id_references(self, material_id: str) -> list[str]:
539535
List of BibTeX references ([str])
540536
"""
541537
docs = self.materials.provenance.search(material_ids=material_id)
542-
543-
if not docs:
544-
return []
545-
546-
return docs[0].references if self.use_document_model else docs[0]["references"] # type: ignore
538+
return docs[0]["references"] if docs else []
547539

548540
def get_material_ids(
549541
self,
@@ -558,17 +550,16 @@ def get_material_ids(
558550
Returns:
559551
List of all materials ids ([MPID])
560552
"""
553+
inp_k = "formula"
561554
if isinstance(chemsys_formula, list) or (
562555
isinstance(chemsys_formula, str) and "-" in chemsys_formula
563556
):
564-
input_params = {"chemsys": chemsys_formula}
565-
else:
566-
input_params = {"formula": chemsys_formula}
557+
inp_k = "chemsys"
567558

568559
return sorted(
569-
doc.material_id if self.use_document_model else doc["material_id"] # type: ignore
560+
doc["material_id"]
570561
for doc in self.materials.search(
571-
**input_params, # type: ignore
562+
**{inp_k: chemsys_formula},
572563
all_fields=False,
573564
fields=["material_id"],
574565
)
@@ -601,10 +592,8 @@ def get_structures(
601592
all_fields=False,
602593
fields=["structure"],
603594
)
604-
if not self.use_document_model:
605-
return [doc["structure"] for doc in docs] # type: ignore
606595

607-
return [doc.structure for doc in docs] # type: ignore
596+
return [doc["structure"] for doc in docs]
608597
else:
609598
structures = []
610599

@@ -613,12 +602,7 @@ def get_structures(
613602
all_fields=False,
614603
fields=["initial_structures"],
615604
):
616-
initial_structures = (
617-
doc.initial_structures # type: ignore
618-
if self.use_document_model
619-
else doc["initial_structures"] # type: ignore
620-
)
621-
structures.extend(initial_structures)
605+
structures.extend(doc["initial_structures"])
622606

623607
return structures
624608

@@ -723,7 +707,7 @@ def get_entries(
723707
if additional_criteria:
724708
input_params = {**input_params, **additional_criteria}
725709

726-
entries = []
710+
entries: set[ComputedStructureEntry] = set()
727711

728712
fields = (
729713
["entries", "thermo_type"]
@@ -738,24 +722,17 @@ def get_entries(
738722
)
739723

740724
for doc in docs:
741-
entry_list = (
742-
doc.entries.values() # type: ignore
743-
if self.use_document_model
744-
else doc["entries"].values() # type: ignore
745-
)
725+
entry_list = doc["entries"].values()
746726
for entry in entry_list:
747-
entry_dict: dict = entry.as_dict() if self.monty_decode else entry # type: ignore
727+
entry_dict: dict = entry.as_dict() if hasattr(entry, "as_dict") else entry # type: ignore
748728
if not compatible_only:
749729
entry_dict["correction"] = 0.0
750730
entry_dict["energy_adjustments"] = []
751731

752732
if property_data:
753-
for property in property_data:
754-
entry_dict["data"][property] = (
755-
doc.model_dump()[property] # type: ignore
756-
if self.use_document_model
757-
else doc[property] # type: ignore
758-
)
733+
entry_dict["data"] = {
734+
property: doc[property] for property in property_data
735+
}
759736

760737
if conventional_unit_cell:
761738
entry_struct = Structure.from_dict(entry_dict["structure"])
@@ -776,15 +753,10 @@ def get_entries(
776753
if "n_atoms" in correction:
777754
correction["n_atoms"] *= site_ratio
778755

779-
entry = (
780-
ComputedStructureEntry.from_dict(entry_dict)
781-
if self.monty_decode
782-
else entry_dict
783-
)
756+
# Need to store object to permit de-duplication
757+
entries.add(ComputedStructureEntry.from_dict(entry_dict))
784758

785-
entries.append(entry)
786-
787-
return entries
759+
return [e if self.monty_decode else e.as_dict() for e in entries]
788760

789761
def get_pourbaix_entries(
790762
self,
@@ -1315,9 +1287,7 @@ def get_wulff_shape(self, material_id: str):
13151287
if not doc:
13161288
return None
13171289

1318-
surfaces: list = (
1319-
doc[0].surfaces if self.use_document_model else doc[0]["surfaces"] # type: ignore
1320-
)
1290+
surfaces: list = doc[0]["surfaces"]
13211291

13221292
lattice = (
13231293
SpacegroupAnalyzer(structure).get_conventional_standard_structure().lattice
@@ -1387,17 +1357,8 @@ def get_charge_density_from_material_id(
13871357
if len(results) == 0:
13881358
return None
13891359

1390-
latest_doc = max( # type: ignore
1391-
results,
1392-
key=lambda x: (
1393-
x.last_updated # type: ignore
1394-
if self.use_document_model
1395-
else x["last_updated"]
1396-
), # type: ignore
1397-
)
1398-
task_id = (
1399-
latest_doc.task_id if self.use_document_model else latest_doc["task_id"]
1400-
)
1360+
latest_doc = max(results, key=lambda x: x["last_updated"])
1361+
task_id = latest_doc["task_id"]
14011362
return self.get_charge_density_from_task_id(task_id, inc_task_doc)
14021363

14031364
def get_download_info(self, material_ids, calc_types=None, file_patterns=None):
@@ -1419,20 +1380,17 @@ def get_download_info(self, material_ids, calc_types=None, file_patterns=None):
14191380
else []
14201381
)
14211382

1422-
meta = {}
1383+
meta = defaultdict(list)
14231384
for doc in self.materials.search( # type: ignore
14241385
task_ids=material_ids,
14251386
fields=["calc_types", "deprecated_tasks", "material_id"],
14261387
):
1427-
doc_dict: dict = doc.model_dump() if self.use_document_model else doc # type: ignore
1428-
for task_id, calc_type in doc_dict["calc_types"].items():
1388+
for task_id, calc_type in doc["calc_types"].items():
14291389
if calc_types and calc_type not in calc_types:
14301390
continue
1431-
mp_id = doc_dict["material_id"]
1432-
if meta.get(mp_id) is None:
1433-
meta[mp_id] = [{"task_id": task_id, "calc_type": calc_type}]
1434-
else:
1435-
meta[mp_id].append({"task_id": task_id, "calc_type": calc_type})
1391+
mp_id = doc["material_id"]
1392+
meta[mp_id].append({"task_id": task_id, "calc_type": calc_type})
1393+
14361394
if not meta:
14371395
raise ValueError(f"No tasks found for material id {material_ids}.")
14381396

mp_api/client/routes/materials/electronic_structure.py

Lines changed: 27 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -276,61 +276,47 @@ def get_bandstructure_from_material_id(
276276
if not bs_doc:
277277
raise MPRestError("No electronic structure data found.")
278278

279-
bs_data = (
280-
bs_doc[0].bandstructure # type: ignore
281-
if self.use_document_model
282-
else bs_doc[0]["bandstructure"] # type: ignore
283-
)
284-
285-
if bs_data is None:
279+
if (bs_data := bs_doc[0]["bandstructure"]) is None:
286280
raise MPRestError(
287281
f"No {path_type.value} band structure data found for {material_id}"
288282
)
289-
else:
290-
bs_data: dict = (
291-
bs_data.model_dump() if self.use_document_model else bs_data # type: ignore
292-
)
293283

294-
if bs_data.get(path_type.value, None):
295-
bs_task_id = bs_data[path_type.value]["task_id"]
296-
else:
284+
bs_data: dict = (
285+
bs_data.model_dump() if self.use_document_model else bs_data # type: ignore
286+
)
287+
288+
if bs_data.get(path_type.value, None) is None:
297289
raise MPRestError(
298290
f"No {path_type.value} band structure data found for {material_id}"
299291
)
300-
else:
301-
bs_doc = es_rester.search(material_ids=material_id, fields=["dos"])
292+
bs_task_id = bs_data[path_type.value]["task_id"]
302293

303-
if not bs_doc:
294+
else:
295+
if not (
296+
bs_doc := es_rester.search(material_ids=material_id, fields=["dos"])
297+
):
304298
raise MPRestError("No electronic structure data found.")
305299

306-
bs_data = (
307-
bs_doc[0].dos # type: ignore
308-
if self.use_document_model
309-
else bs_doc[0]["dos"] # type: ignore
310-
)
311-
312-
if bs_data is None:
300+
if (bs_data := bs_doc[0]["dos"]) is None:
313301
raise MPRestError(
314302
f"No uniform band structure data found for {material_id}"
315303
)
316-
else:
317-
bs_data: dict = (
318-
bs_data.model_dump() if self.use_document_model else bs_data # type: ignore
319-
)
320304

321-
if bs_data.get("total", None):
322-
bs_task_id = bs_data["total"]["1"]["task_id"]
323-
else:
305+
bs_data: dict = (
306+
bs_data.model_dump() if self.use_document_model else bs_data # type: ignore
307+
)
308+
309+
if bs_data.get("total", None) is None:
324310
raise MPRestError(
325311
f"No uniform band structure data found for {material_id}"
326312
)
313+
bs_task_id = bs_data["total"]["1"]["task_id"]
327314

328315
bs_obj = self.get_bandstructure_from_task_id(bs_task_id)
329316

330317
if bs_obj:
331318
return bs_obj
332-
else:
333-
raise MPRestError("No band structure object found.")
319+
raise MPRestError("No band structure object found.")
334320

335321

336322
class DosRester(BaseRester):
@@ -456,22 +442,16 @@ def get_dos_from_material_id(self, material_id: str):
456442
mute_progress_bars=self.mute_progress_bars,
457443
)
458444

459-
dos_doc = es_rester.search(material_ids=material_id, fields=["dos"])
460-
if not dos_doc:
445+
if not (dos_doc := es_rester.search(material_ids=material_id, fields=["dos"])):
461446
return None
462447

463-
dos_data: dict = (
464-
dos_doc[0].model_dump() if self.use_document_model else dos_doc[0] # type: ignore
465-
)
466-
467-
if dos_data["dos"]:
468-
dos_task_id = dos_data["dos"]["total"]["1"]["task_id"]
469-
else:
448+
if not (dos_data := dos_doc[0].get("dos")):
470449
raise MPRestError(f"No density of states data found for {material_id}")
471450

472-
dos_obj = self.get_dos_from_task_id(dos_task_id)
473-
474-
if dos_obj:
451+
dos_task_id = (dos_data.model_dump() if self.use_document_model else dos_data)[
452+
"total"
453+
]["1"]["task_id"]
454+
if dos_obj := self.get_dos_from_task_id(dos_task_id):
475455
return dos_obj
476-
else:
477-
raise MPRestError("No density of states object found.")
456+
457+
raise MPRestError("No density of states object found.")

0 commit comments

Comments
 (0)