Skip to content

Commit 73e9a50

Browse files
precomit
1 parent ac59383 commit 73e9a50

6 files changed

Lines changed: 66 additions & 44 deletions

File tree

mp_api/client/core/client.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
from typing import Any, Callable
5050

5151
from pydantic.fields import FieldInfo
52+
5253
from mp_api.client.core.utils import LazyImport
5354

5455
try:
@@ -1353,6 +1354,7 @@ def __str__(self): # pragma: no cover
13531354
f"Available fields: {', '.join(self.available_fields)}\n\n"
13541355
)
13551356

1357+
13561358
class CoreRester(BaseRester):
13571359
"""Define a BaseRester with extra features for core resters.
13581360
@@ -1361,7 +1363,8 @@ class CoreRester(BaseRester):
13611363
of endpoints names to LazyImport objects.
13621364
13631365
"""
1364-
_sub_resters : dict[str,LazyImport] = {}
1366+
1367+
_sub_resters: dict[str, LazyImport] = {}
13651368

13661369
def __getattr__(self, v: str):
13671370
if v in self._sub_resters:

mp_api/client/core/utils.py

Lines changed: 33 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ def validate_ids(id_list: list[str]) -> list[str]:
7777

7878
class LazyImport:
7979
"""Lazily import and load an object.
80-
80+
8181
This class is super lazy, in that it lazily imports and caches an object.
8282
If the object is a function, the function itself will be cached.
8383
@@ -90,38 +90,52 @@ class LazyImport:
9090
A dot-separated, import-like string.
9191
"""
9292

93-
__slots__ = ["_module_name", "_class_name", "_obj","_imported"]
93+
__slots__ = ["_module_name", "_class_name", "_obj", "_imported"]
94+
95+
def __init__(
96+
self,
97+
import_str: str,
98+
) -> None:
99+
"""Initialize a lazily imported object.
94100
95-
def __init__(self, import_str: str,) -> None:
96-
if len(
97-
splitted := import_str.rsplit(".",1)
98-
) > 1:
99-
self._module_name, self._class_name = splitted
101+
Parameters
102+
-----------
103+
import_str : str
104+
A dot-separated, import-like string.
105+
"""
106+
if len(split_import_str := import_str.rsplit(".", 1)) > 1:
107+
self._module_name, self._class_name = split_import_str
100108
else:
101-
self._module_name = splitted[0]
109+
self._module_name = split_import_str[0]
102110
self._class_name = None
103111

104-
self._imported : Any | None = None
105-
self._obj : Any | None = None
112+
self._imported: Any | None = None
113+
self._obj: Any | None = None
106114

107115
def __str__(self) -> str:
108-
return f"LazyImport of {self._module_name}" + (f".{self._class_name}" if self._class_name else "")
116+
return f"LazyImport of {self._module_name}" + (
117+
f".{self._class_name}" if self._class_name else ""
118+
)
109119

110120
def __repr__(self) -> str:
111121
return self.__str__()
112122

113-
def _load(self,) -> None:
123+
def _load(
124+
self,
125+
) -> None:
114126
try:
115127
_imported = import_module(self._module_name)
116128
if self._class_name:
117129
_imported = getattr(_imported, self._class_name)
118130
self._imported = _imported
119131
except Exception as exc:
120-
raise ImportError(f"Failed to import {self._module_name}.{self._class_name}:\n{exc}")
132+
raise ImportError(
133+
f"Failed to import {self._module_name}.{self._class_name}:\n{exc}"
134+
)
121135

122136
def __call__(self, *args, **kwargs) -> Any:
123137
"""Call a function or (re-)initialize a class.
124-
138+
125139
If the object itself has not been imported, this will first import it.
126140
127141
If the object is a class, it will be initialized, cached, and returned.
@@ -132,20 +146,19 @@ def __call__(self, *args, **kwargs) -> Any:
132146
if self._imported is None:
133147
self._load()
134148

135-
if isinstance(self._imported,type):
149+
if isinstance(self._imported, type):
136150
self._obj = self._imported(*args, **kwargs)
137151
return self._obj
138152
else:
139153
self._obj = self._imported
140-
return self._obj(*args,**kwargs)
154+
return self._obj(*args, **kwargs)
141155

142-
def __getattr__(self, v : str) -> Any:
156+
def __getattr__(self, v: str) -> Any:
143157
"""Get an attribute on a super lazy object."""
144158
if self._obj is not None and hasattr(self._obj, v):
145159
return getattr(self._obj, v)
146160

147161
if self._imported is None:
148162
self._load()
149-
if hasattr(self._imported,v):
150-
return getattr(self._imported,v)
151-
163+
if hasattr(self._imported, v):
164+
return getattr(self._imported, v)

mp_api/client/mprester.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@
3131
from mp_api.client.routes.materials import MATERIALS_RESTERS
3232
from mp_api.client.routes.molecules import MOLECULES_RESTERS
3333

34-
3534
if TYPE_CHECKING:
3635
from typing import Any, Literal
3736

@@ -43,7 +42,9 @@
4342
DEFAULT_THERMOTYPE_CRITERIA = {"thermo_types": ["GGA_GGA+U"]}
4443

4544
RESTER_LAYOUT = {
46-
"molecules/core": LazyImport("mp_api.client.routes.molecules.molecules.MoleculeRester"),
45+
"molecules/core": LazyImport(
46+
"mp_api.client.routes.molecules.molecules.MoleculeRester"
47+
),
4748
"materials/core": MATERIALS_RESTERS["materials"],
4849
**{
4950
f"materials/{k}": v
@@ -54,7 +55,10 @@
5455
**{
5556
f"molecules/{k}": v
5657
for k, v in MOLECULES_RESTERS.items()
57-
if k not in {"molecules",}
58+
if k
59+
not in {
60+
"molecules",
61+
}
5862
},
5963
**GENERIC_RESTERS,
6064
}

mp_api/client/routes/molecules/__init__.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,10 @@
33
from mp_api.client.core.utils import LazyImport
44

55
MOLECULES_RESTERS = {
6-
k : LazyImport(f"mp_api.client.routes.molecules.{k}.{v}")
6+
k: LazyImport(f"mp_api.client.routes.molecules.{k}.{v}")
77
for k, v in (
88
("molecules", "MoleculeRester"),
99
("jcser", "JcesrMoleculesRester"),
10-
("summary", "MoleculesSummaryRester")
10+
("summary", "MoleculesSummaryRester"),
1111
)
12-
}
12+
}

tests/core/test_utils.py

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
1-
21
from mp_api.client.core.utils import LazyImport
32

4-
def test_lazy_import_function():
53

4+
def test_lazy_import_function():
65
import_str = "json.dumps"
76
lazy_func = LazyImport(import_str)
87
assert lazy_func._module_name == "json"
@@ -11,24 +10,21 @@ def test_lazy_import_function():
1110

1211
jsonables = [
1312
{"apple": "pineapple", "banana": "orange"},
14-
[1,2,3,4,5],
15-
[{"nothing": {"of": {"grand": "import"}}}]
13+
[1, 2, 3, 4, 5],
14+
[{"nothing": {"of": {"grand": "import"}}}],
1615
]
1716

18-
dumped = [
19-
lazy_func(jsonable) for jsonable in jsonables
20-
]
17+
dumped = [lazy_func(jsonable) for jsonable in jsonables]
2118

2219
import json as _real_json
23-
20+
2421
assert lazy_func._imported == _real_json.dumps
2522
assert all(
26-
dumped[i] == _real_json.dumps(jsonable)
27-
for i, jsonable in enumerate(jsonables)
23+
dumped[i] == _real_json.dumps(jsonable) for i, jsonable in enumerate(jsonables)
2824
)
2925

30-
def test_lazy_import_class():
3126

27+
def test_lazy_import_class():
3228
import_str = "pymatgen.core.Structure"
3329
lazy_class = LazyImport(import_str)
3430
assert lazy_class._module_name == "pymatgen.core"
@@ -45,12 +41,15 @@ def test_lazy_import_class():
4541
direct
4642
0.8750000000000000 0.8750000000000000 0.8750000000000000 Si
4743
0.1250000000000000 0.1250000000000000 0.1250000000000000 Si"""
48-
44+
4945
# test construction from classmethod
50-
struct_from_str = lazy_class.from_str(structure_str, fmt = "poscar")
46+
struct_from_str = lazy_class.from_str(structure_str, fmt="poscar")
5147
# test re-init
52-
struct_from_init = lazy_class(struct_from_str.lattice,struct_from_str.species,struct_from_str.frac_coords)
48+
struct_from_init = lazy_class(
49+
struct_from_str.lattice, struct_from_str.species, struct_from_str.frac_coords
50+
)
5351
assert struct_from_str == struct_from_init
5452

5553
from pymatgen.core.structure import Structure
56-
assert Structure.from_str(structure_str,fmt="poscar") == struct_from_str
54+
55+
assert Structure.from_str(structure_str, fmt="poscar") == struct_from_str

tests/test_client.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,9 +49,12 @@
4949

5050
# Temporarily ignore molecules resters while molecules query operators are changed
5151
resters_to_test = [
52-
rester for rester in mpr._all_resters if "molecule" not in rester._class_name.lower()
52+
rester
53+
for rester in mpr._all_resters
54+
if "molecule" not in rester._class_name.lower()
5355
]
5456

57+
5558
@pytest.mark.skipif(os.getenv("MP_API_KEY") is None, reason="No API key found.")
5659
@pytest.mark.parametrize("rester", resters_to_test)
5760
def test_generic_get_methods(rester):
@@ -65,7 +68,7 @@ def test_generic_get_methods(rester):
6568
monty_decode=rester not in [TaskRester, ProvenanceRester],
6669
use_document_model=True,
6770
)
68-
71+
6972
name = rester.suffix.replace("/", "_")
7073

7174
docs_check = lambda _docs: all(

0 commit comments

Comments
 (0)