Skip to content

Commit ac59383

Browse files
revise lazy loading, correct rester init
1 parent e31435a commit ac59383

10 files changed

Lines changed: 196 additions & 182 deletions

File tree

mp_api/client/core/client.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
from typing import Any, Callable
5050

5151
from pydantic.fields import FieldInfo
52+
from mp_api.client.core.utils import LazyImport
5253

5354
try:
5455
__version__ = version("mp_api")
@@ -1352,6 +1353,40 @@ def __str__(self): # pragma: no cover
13521353
f"Available fields: {', '.join(self.available_fields)}\n\n"
13531354
)
13541355

1356+
class CoreRester(BaseRester):
1357+
"""Define a BaseRester with extra features for core resters.
1358+
1359+
Enables lazy importing / initialization of sub resters
1360+
provided in `_sub_resters`, which should be a map
1361+
of endpoints names to LazyImport objects.
1362+
1363+
"""
1364+
_sub_resters : dict[str,LazyImport] = {}
1365+
1366+
def __getattr__(self, v: str):
1367+
if v in self._sub_resters:
1368+
if self._sub_resters[v]._obj is None:
1369+
# TODO: Enable monty decoding when tasks and SNL schema is normalized
1370+
monty_disable = self._sub_resters[v]._class_name in [
1371+
"TaskRester",
1372+
"ProvenanceRester",
1373+
]
1374+
1375+
self._sub_resters[v](
1376+
api_key=self.api_key,
1377+
endpoint=self.endpoint.split(self.suffix)[0],
1378+
include_user_agent=self._include_user_agent,
1379+
session=self.session,
1380+
monty_decode=False if monty_disable else self.monty_decode,
1381+
use_document_model=self.use_document_model,
1382+
headers=self.headers,
1383+
mute_progress_bars=self.mute_progress_bars,
1384+
)
1385+
return self._sub_resters[v]
1386+
1387+
def __dir__(self):
1388+
return dir(self.__class__) + list(self._sub_resters)
1389+
13551390

13561391
class MPRestError(Exception):
13571392
"""Raised when the query has problems, e.g., bad query format."""

mp_api/client/core/utils.py

Lines changed: 65 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -76,35 +76,76 @@ def validate_ids(id_list: list[str]) -> list[str]:
7676

7777

7878
class LazyImport:
79-
"""Lazily import and load an object."""
79+
"""Lazily import and load an object.
80+
81+
This class is super lazy, in that it lazily imports and caches an object.
82+
If the object is a function, the function itself will be cached.
8083
81-
__slots__ = ["_module_name", "_class_name", "_obj"]
84+
If the object is a class, and the class is initialized, the
85+
current instance of the class will be cached.
8286
83-
def __init__( # noqa: D107
84-
self,
85-
module_name: str,
86-
class_name: str,
87-
) -> None:
88-
self._module_name = module_name
89-
self._class_name = class_name
90-
self._obj = None
87+
Parameters
88+
-----------
89+
import_str : str
90+
A dot-separated, import-like string.
91+
"""
92+
93+
__slots__ = ["_module_name", "_class_name", "_obj","_imported"]
94+
95+
def __init__(self, import_str: str,) -> None:
96+
if len(
97+
splitted := import_str.rsplit(".",1)
98+
) > 1:
99+
self._module_name, self._class_name = splitted
100+
else:
101+
self._module_name = splitted[0]
102+
self._class_name = None
103+
104+
self._imported : Any | None = None
105+
self._obj : Any | None = None
91106

92107
def __str__(self) -> str:
93-
return f"LazyImport of {self._module_name}.{self._class_name}"
108+
return f"LazyImport of {self._module_name}" + (f".{self._class_name}" if self._class_name else "")
94109

95110
def __repr__(self) -> str:
96111
return self.__str__()
97112

98-
def __call__(self, *args, **kwargs):
99-
if self._obj is None:
100-
try:
101-
self._obj = getattr(import_module(self._module_name), self._class_name)(
102-
*args,
103-
**kwargs,
104-
)
105-
except Exception as exc:
106-
raise ImportError(f"Failed to import {self._class_name}:\n{exc}")
107-
return self._obj
108-
109-
def __getattr__(self, v):
110-
return getattr(self._obj, v)
113+
def _load(self,) -> None:
114+
try:
115+
_imported = import_module(self._module_name)
116+
if self._class_name:
117+
_imported = getattr(_imported, self._class_name)
118+
self._imported = _imported
119+
except Exception as exc:
120+
raise ImportError(f"Failed to import {self._module_name}.{self._class_name}:\n{exc}")
121+
122+
def __call__(self, *args, **kwargs) -> Any:
123+
"""Call a function or (re-)initialize a class.
124+
125+
If the object itself has not been imported, this will first import it.
126+
127+
If the object is a class, it will be initialized, cached, and returned.
128+
129+
If the object is a function, it will be cached, and this will return
130+
the value(s) of the function at (*args,**kwargs).
131+
"""
132+
if self._imported is None:
133+
self._load()
134+
135+
if isinstance(self._imported,type):
136+
self._obj = self._imported(*args, **kwargs)
137+
return self._obj
138+
else:
139+
self._obj = self._imported
140+
return self._obj(*args,**kwargs)
141+
142+
def __getattr__(self, v : str) -> Any:
143+
"""Get an attribute on a super lazy object."""
144+
if self._obj is not None and hasattr(self._obj, v):
145+
return getattr(self._obj, v)
146+
147+
if self._imported is None:
148+
self._load()
149+
if hasattr(self._imported,v):
150+
return getattr(self._imported,v)
151+

mp_api/client/mprester.py

Lines changed: 10 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,9 @@
2828
from mp_api.client.core.settings import MAPIClientSettings
2929
from mp_api.client.core.utils import LazyImport, load_json, validate_ids
3030
from mp_api.client.routes import GENERIC_RESTERS
31-
from mp_api.client.routes.materials.materials import MATERIALS_RESTERS
31+
from mp_api.client.routes.materials import MATERIALS_RESTERS
32+
from mp_api.client.routes.molecules import MOLECULES_RESTERS
33+
3234

3335
if TYPE_CHECKING:
3436
from typing import Any, Literal
@@ -41,15 +43,20 @@
4143
DEFAULT_THERMOTYPE_CRITERIA = {"thermo_types": ["GGA_GGA+U"]}
4244

4345
RESTER_LAYOUT = {
46+
"molecules/core": LazyImport("mp_api.client.routes.molecules.molecules.MoleculeRester"),
47+
"materials/core": MATERIALS_RESTERS["materials"],
4448
**{
4549
f"materials/{k}": v
4650
for k, v in MATERIALS_RESTERS.items()
4751
if k not in {"materials", "doi"}
4852
},
49-
"materials/core": MATERIALS_RESTERS["materials"],
5053
"doi": MATERIALS_RESTERS["doi"],
54+
**{
55+
f"molecules/{k}": v
56+
for k, v in MOLECULES_RESTERS.items()
57+
if k not in {"molecules",}
58+
},
5159
**GENERIC_RESTERS,
52-
"molecules/core": LazyImport("mp_api.client.routes.molecules", "MoleculeRester"),
5360
}
5461

5562

@@ -200,82 +207,10 @@ def __init__(
200207
}
201208

202209
# Set remaining top level resters, or get an attribute-class name mapping
203-
# for all sub-resters
204-
_sub_rester_suffix_map = {"materials": {}, "molecules": {}}
205-
206-
# for cls in self._all_resters:
207-
# if cls.suffix not in core_suffix:
208-
# suffix_split = cls.suffix.split("/")
209-
210-
# if len(suffix_split) == 1:
211-
# # Disable monty decode on nested data which may give errors
212-
# monty_disable = cls.__name__ in ["TaskRester", "ProvenanceRester"]
213-
# monty_decode = False if monty_disable else self.monty_decode
214-
# rester = cls(
215-
# api_key=api_key,
216-
# endpoint=self.endpoint,
217-
# include_user_agent=include_user_agent,
218-
# session=self.session,
219-
# monty_decode=monty_decode,
220-
# use_document_model=self.use_document_model,
221-
# headers=self.headers,
222-
# mute_progress_bars=self.mute_progress_bars,
223-
# ) # type: BaseRester
224-
# setattr(
225-
# self,
226-
# suffix_split[0],
227-
# rester,
228-
# )
229-
# else:
230-
# attr = "_".join(suffix_split[1:])
231-
# if "materials" in suffix_split:
232-
# _sub_rester_suffix_map["materials"][attr] = cls
233-
# elif "molecules" in suffix_split:
234-
# _sub_rester_suffix_map["molecules"][attr] = cls
235-
236-
# TODO: Enable monty decoding when tasks and SNL schema is normalized
237-
#
238-
# Allow lazy loading of nested resters under materials and molecules using custom __getattr__ methods
239-
def __core_custom_getattr(_self, _attr, _rester_map):
240-
if _attr in RESTER_LAYOUT:
241-
lazy_rester = RESTER_LAYOUT[_attr]
242-
monty_disable = lazy_rester._class_name in [
243-
"TaskRester",
244-
"ProvenanceRester",
245-
]
246-
monty_decode = False if monty_disable else self.monty_decode
247-
rester = lazy_rester(
248-
api_key=api_key,
249-
endpoint=self.endpoint,
250-
include_user_agent=include_user_agent,
251-
session=self.session,
252-
monty_decode=monty_decode,
253-
use_document_model=self.use_document_model,
254-
headers=self.headers,
255-
mute_progress_bars=self.mute_progress_bars,
256-
) # type: BaseRester
257-
258-
return rester
259-
else:
260-
raise AttributeError(
261-
f"{_self.__class__.__name__!r} object has no attribute {_attr!r}"
262-
)
263-
264-
def __materials_getattr__(_self, attr):
265-
rester = __core_custom_getattr(_self, attr, MATERIALS_RESTERS)
266-
return rester
267-
268-
def __molecules_getattr__(_self, attr):
269-
_rester_map = _sub_rester_suffix_map["molecules"]
270-
rester = __core_custom_getattr(_self, attr, _rester_map)
271-
return rester
272210

273211
for attr, rester in core_resters.items():
274212
setattr(self, attr, rester)
275213

276-
# self.materials.__getattr__ = __materials_getattr__ # type: ignore
277-
# MoleculeRester.__getattr__ = __molecules_getattr__ # type: ignore
278-
279214
@property
280215
def contribs(self):
281216
if self._contribs is None:

mp_api/client/routes/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from mp_api.client.core.utils import LazyImport
44

55
GENERIC_RESTERS = {
6-
k: LazyImport(f"mp_api.client.routes.{k}", v)
6+
k: LazyImport(f"mp_api.client.routes.{k}.{v}")
77
for k, v in {
88
"_general_store": "GeneralStoreRester",
99
"_messages": "MessagesRester",

mp_api/client/routes/materials/__init__.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,11 @@
44
from mp_api.client.core.utils import LazyImport
55

66
MATERIALS_RESTERS: dict[str, LazyImport] = {
7-
route: LazyImport(f"mp_api.client.routes.materials.{module_name}", cls_name)
8-
for route, module_name, cls_name in [
7+
route: LazyImport(f"mp_api.client.routes.materials.{module_name}.{cls_name}")
8+
for route, module_name, cls_name in (
99
("absorption", "absorption", "AbsorptionRester"),
1010
("alloys", "alloys", "AlloysRester"),
1111
("bonds", "bonds", "BondsRester"),
12-
("charge_density", "charge_density", "ChargeDensityRester"),
1312
(
1413
"chemenv",
1514
"chemenv",
@@ -33,7 +32,7 @@
3332
("materials", "materials", "MaterialsRester"),
3433
("oxidation_states", "oxidation_states", "OxidationStatesRester"),
3534
("phonon", "phonon", "PhononRester"),
36-
("piezoelectric", "piezoelectric", "PiezoRester"),
35+
("piezoelectric", "piezo", "PiezoRester"),
3736
("provenance", "provenance", "ProvenanceRester"),
3837
("robocrys", "robocrys", "RobocrysRester"),
3938
("similarity", "similarity", "SimilarityRester"),
@@ -44,5 +43,5 @@
4443
("tasks", "tasks", "TaskRester"),
4544
("thermo", "thermo", "ThermoRester"),
4645
("xas", "xas", "XASRester"),
47-
]
46+
)
4847
}

mp_api/client/routes/materials/materials.py

Lines changed: 3 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -5,71 +5,19 @@
55
from emmet.core.vasp.material import MaterialsDoc
66
from pymatgen.core.structure import Structure
77

8-
from mp_api.client.core import BaseRester, MPRestError
8+
from mp_api.client.core.client import CoreRester, MPRestError
99
from mp_api.client.core.utils import validate_ids
1010
from mp_api.client.routes.materials import MATERIALS_RESTERS
1111

1212
_EMMET_SETTINGS = EmmetSettings() # type: ignore
1313

1414

15-
class MaterialsRester(BaseRester):
15+
class MaterialsRester(CoreRester):
1616
suffix = "materials/core"
1717
document_model = MaterialsDoc # type: ignore
1818
supports_versions = True
1919
primary_key = "material_id"
20-
_sub_resters = [
21-
"eos",
22-
"similarity",
23-
"tasks",
24-
"xas",
25-
"grain_boundaries",
26-
"substrates",
27-
"surface_properties",
28-
"phonon",
29-
"elasticity",
30-
"thermo",
31-
"dielectric",
32-
"piezoelectric",
33-
"magnetism",
34-
"summary",
35-
"robocrys",
36-
"synthesis",
37-
"insertion_electrodes",
38-
"conversion_electrodes",
39-
"electronic_structure",
40-
"electronic_structure_bandstructure",
41-
"electronic_structure_dos",
42-
"oxidation_states",
43-
"provenance",
44-
"bonds",
45-
"alloys",
46-
"absorption",
47-
"chemenv",
48-
]
49-
50-
def __getattr__(self, v: str):
51-
if v in self._sub_resters:
52-
if MATERIALS_RESTERS[v]._obj is None:
53-
# TODO: Enable monty decoding when tasks and SNL schema is normalized
54-
monty_disable = MATERIALS_RESTERS[v]._class_name in [
55-
"TaskRester",
56-
"ProvenanceRester",
57-
]
58-
59-
MATERIALS_RESTERS[v](
60-
api_key=self.api_key,
61-
endpoint=self.endpoint.split(self.suffix)[0],
62-
include_user_agent=self._include_user_agent,
63-
session=self.session,
64-
monty_decode=False if monty_disable else self.monty_decode,
65-
use_document_model=self.use_document_model,
66-
headers=self.headers,
67-
mute_progress_bars=self.mute_progress_bars,
68-
)
69-
return MATERIALS_RESTERS[v]
70-
71-
def __dir__(self):
72-
return dir(MaterialsRester) + self._sub_resters
20+
_sub_resters = MATERIALS_RESTERS
7321

7422
def get_structure_by_material_id(
7523
self, material_id: str, final: bool = True

0 commit comments

Comments
 (0)