Skip to content

Commit 066ca12

Browse files
ensure sub resters get unset on reinit
1 parent 5e3289d commit 066ca12

4 files changed

Lines changed: 26 additions & 5 deletions

File tree

mp_api/client/core/client.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1350,10 +1350,15 @@ class CoreRester(BaseRester):
13501350

13511351
_sub_resters: dict[str, LazyImport] = {}
13521352

1353+
def __init__(self, **kwargs):
1354+
"""Ensure that sub resters are unset on re-init."""
1355+
super().__init__(**kwargs)
1356+
self.sub_resters = {k: v.copy() for k, v in self._sub_resters.items()}
1357+
13531358
def __getattr__(self, v: str):
1354-
if v in self._sub_resters:
1355-
if self._sub_resters[v]._obj is None:
1356-
self._sub_resters[v](
1359+
if v in self.sub_resters:
1360+
if self.sub_resters[v]._obj is None:
1361+
self.sub_resters[v](
13571362
api_key=self.api_key,
13581363
endpoint=self.base_endpoint,
13591364
include_user_agent=self._include_user_agent,
@@ -1362,7 +1367,7 @@ def __getattr__(self, v: str):
13621367
headers=self.headers,
13631368
mute_progress_bars=self.mute_progress_bars,
13641369
)
1365-
return self._sub_resters[v]
1370+
return self.sub_resters[v]
13661371

13671372
def __dir__(self):
13681373
return dir(self.__class__) + list(self._sub_resters)

mp_api/client/core/utils.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,13 @@ def __init__(
160160
self._imported: Any | None = None
161161
self._obj: Any | None = None
162162

163+
def copy(self) -> LazyImport:
164+
"""Return a new copy of the current instance."""
165+
return LazyImport(
166+
f"{self._module_name}"
167+
+ (f".{self._class_name}" if self._class_name else "")
168+
)
169+
163170
def __str__(self) -> str:
164171
return f"LazyImport of {self._module_name}" + (
165172
f".{self._class_name}" if self._class_name else ""

tests/core/test_utils.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,14 @@ def test_lazy_import_class():
7777

7878
assert Structure.from_str(structure_str, fmt="poscar") == struct_from_str
7979

80+
# ensure copy yields an independent object
81+
lazy_copy = lazy_class.copy()
82+
lazy_class(
83+
struct_from_str.lattice, struct_from_str.species, struct_from_str.frac_coords
84+
)
85+
assert lazy_copy._obj is None
86+
assert lazy_class._obj == struct_from_str
87+
8088

8189
def test_emmet_core_version_checks(monkeypatch: pytest.MonkeyPatch):
8290
ref_ver = (1, 2, "3rc5")

tests/molecules/test_molecules.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,5 +15,6 @@ def test_molecule_rester():
1515
)
1616

1717
assert all(
18-
getattr(rester, k) == lazy_obj for k, lazy_obj in MOLECULES_RESTERS.items()
18+
getattr(rester, k)._class_name == lazy_obj._class_name
19+
for k, lazy_obj in MOLECULES_RESTERS.items()
1920
)

0 commit comments

Comments
 (0)