Skip to content

Commit 300a2fa

Browse files
draft lazy loading
1 parent a8210aa commit 300a2fa

5 files changed

Lines changed: 173 additions & 310 deletions

File tree

mp_api/client/core/utils.py

Lines changed: 45 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,18 @@
11
from __future__ import annotations
22

3-
import re
3+
from importlib import import_module
44
from typing import TYPE_CHECKING, Literal
55

66
import orjson
77
from emmet.core import __version__ as _EMMET_CORE_VER
8+
from emmet.core.mpid_ext import validate_identifier
89
from monty.json import MontyDecoder
910
from packaging.version import parse as parse_version
1011

1112
from mp_api.client.core.settings import MAPIClientSettings
1213

1314
if TYPE_CHECKING:
14-
from monty.json import MSONable
15+
from typing import Any
1516

1617

1718
def _compare_emmet_ver(
@@ -23,6 +24,10 @@ def _compare_emmet_ver(
2324
_compare_emmet_ver("0.84.0rc0","<") returns
2425
emmet.core.__version__ < "0.84.0rc0"
2526
27+
This function may not be used anywhere in the client, but it should
28+
be preserved for future use, in case some degree of backwards
29+
compatibility or feature buy-in is needed.
30+
2631
Parameters
2732
-----------
2833
ref_version : str
@@ -36,41 +41,17 @@ def _compare_emmet_ver(
3641
)(parse_version(ref_version))
3742

3843

39-
if _compare_emmet_ver("0.85.0", ">="):
40-
from emmet.core.mpid_ext import validate_identifier
41-
else:
42-
validate_identifier = None
43-
44-
45-
def load_json(json_like: str | bytes, deser: bool = False, encoding: str = "utf-8"):
44+
def load_json(
45+
json_like: str | bytes, deser: bool = False, encoding: str = "utf-8"
46+
) -> Any:
4647
"""Utility to load json in consistent manner."""
4748
data = orjson.loads(
4849
json_like if isinstance(json_like, bytes) else json_like.encode(encoding)
4950
)
5051
return MontyDecoder().process_decoded(data) if deser else data
5152

5253

53-
def _legacy_id_validation(id_list: list[str]) -> list[str]:
54-
"""Legacy utility to validate IDs, pre-AlphaID transition.
55-
56-
This function is temporarily maintained to allow for
57-
backwards compatibility with older versions of emmet, and will
58-
not be preserved.
59-
"""
60-
pattern = "(mp|mvc|mol|mpcule)-.*"
61-
if malformed_ids := {
62-
entry for entry in id_list if re.match(pattern, entry) is None
63-
}:
64-
raise ValueError(
65-
f"{'Entry' if len(malformed_ids) == 1 else 'Entries'}"
66-
f" {', '.join(malformed_ids)}"
67-
f"{'is' if len(malformed_ids) == 1 else 'are'} not formatted correctly!"
68-
)
69-
70-
return id_list
71-
72-
73-
def validate_ids(id_list: list[str]):
54+
def validate_ids(id_list: list[str]) -> list[str]:
7455
"""Function to validate material and task IDs.
7556
7657
Args:
@@ -91,36 +72,38 @@ def validate_ids(id_list: list[str]):
9172
# TODO: after the transition to AlphaID in the document models,
9273
# The following line should be changed to
9374
# return [validate_identifier(idx,serialize=True) for idx in id_list]
94-
if validate_identifier:
95-
return [str(validate_identifier(idx)) for idx in id_list]
96-
return _legacy_id_validation(id_list)
97-
98-
99-
def allow_msonable_dict(monty_cls: type[MSONable]):
100-
"""Patch Monty to allow for dict values for MSONable."""
101-
102-
def validate_monty(cls, v, _):
103-
"""Stub validator for MSONable as a dictionary only."""
104-
if isinstance(v, cls):
105-
return v
106-
elif isinstance(v, dict):
107-
# Just validate the simple Monty Dict Model
108-
errors = []
109-
if v.get("@module", "") != monty_cls.__module__:
110-
errors.append("@module")
111-
112-
if v.get("@class", "") != monty_cls.__name__:
113-
errors.append("@class")
114-
115-
if len(errors) > 0:
116-
raise ValueError(
117-
"Missing Monty seriailzation fields in dictionary: {errors}"
75+
return [str(validate_identifier(idx)) for idx in id_list]
76+
77+
class LazyImport:
78+
79+
__slots__ = ["_module_name", "_class_name", "_obj"]
80+
81+
def __init__(self, module_name : str, class_name : str,) -> None:
82+
self._module_name = module_name
83+
self._class_name = class_name
84+
self._obj = None
85+
86+
def __str__(self) -> str:
87+
return f"LazyImport of {self._module_name}.{self._class_name}"
88+
89+
def __repr__(self) -> str:
90+
return self.__str__()
91+
92+
def __call__(self, *args, **kwargs):
93+
if self._obj is None:
94+
try:
95+
self._obj = getattr(
96+
import_module(self._module_name),
97+
self._class_name
98+
)(
99+
*args,
100+
**kwargs,
118101
)
119-
120-
return v
121-
else:
122-
raise ValueError(f"Must provide {cls.__name__} or MSONable dictionary")
123-
124-
monty_cls.validate_monty_v2 = classmethod(validate_monty)
125-
126-
return monty_cls
102+
except Exception as exc:
103+
raise ImportError(
104+
f"Failed to import {self._class_name}:\n{exc}"
105+
)
106+
return self._obj
107+
108+
def __getattr__(self,v):
109+
return getattr(self._obj,v)

0 commit comments

Comments
 (0)