Skip to content

Commit eaca434

Browse files
settings var + jcesr typo fix
1 parent 912174c commit eaca434

11 files changed

Lines changed: 82 additions & 67 deletions

File tree

mp_api/client/core/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,4 @@
22

33
from .client import BaseRester
44
from .exceptions import MPRestError, MPRestWarning
5-
from .settings import MAPIClientSettings
5+
from .settings import MAPI_CLIENT_SETTINGS, MAPIClientSettings

mp_api/client/core/client.py

Lines changed: 25 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from json import JSONDecodeError
2020
from math import ceil
2121
from typing import TYPE_CHECKING, ForwardRef, Optional, get_args
22-
from urllib.parse import quote, urljoin
22+
from urllib.parse import quote
2323

2424
import requests
2525
from emmet.core.utils import jsanitize
@@ -31,8 +31,13 @@
3131
from urllib3.util.retry import Retry
3232

3333
from mp_api.client.core.exceptions import MPRestError
34-
from mp_api.client.core.settings import MAPIClientSettings
35-
from mp_api.client.core.utils import load_json, validate_ids
34+
from mp_api.client.core.settings import MAPI_CLIENT_SETTINGS
35+
from mp_api.client.core.utils import (
36+
load_json,
37+
validate_api_key,
38+
validate_endpoint,
39+
validate_ids,
40+
)
3641

3742
try:
3843
import boto3
@@ -59,9 +64,6 @@
5964
__version__ = os.getenv("SETUPTOOLS_SCM_PRETEND_VERSION")
6065

6166

62-
SETTINGS = MAPIClientSettings() # type: ignore
63-
64-
6567
class _DictLikeAccess(BaseModel):
6668
"""Define a pydantic mix-in which permits dict-like access to model fields."""
6769

@@ -98,7 +100,7 @@ def __init__(
98100
use_document_model: bool = True,
99101
timeout: int = 20,
100102
headers: dict | None = None,
101-
mute_progress_bars: bool = SETTINGS.MUTE_PROGRESS_BARS,
103+
mute_progress_bars: bool = MAPI_CLIENT_SETTINGS.MUTE_PROGRESS_BARS,
102104
**kwargs,
103105
):
104106
"""Initialize the REST API helper class.
@@ -132,23 +134,17 @@ def __init__(
132134
mute_progress_bars: Whether to disable progress bars.
133135
**kwargs: access to legacy kwargs that may be in the process of being deprecated
134136
"""
135-
# TODO: think about how to migrate from PMG_MAPI_KEY
136-
self.api_key = api_key or os.getenv("MP_API_KEY")
137-
self.base_endpoint = self.endpoint = endpoint or os.getenv(
138-
"MP_API_ENDPOINT", "https://api.materialsproject.org/"
139-
)
137+
self.api_key = validate_api_key(api_key)
138+
self.base_endpoint = validate_endpoint(endpoint)
139+
self.endpoint = validate_endpoint(endpoint, suffix=self.suffix)
140+
140141
self.debug = debug
141142
self.include_user_agent = include_user_agent
142143
self.use_document_model = use_document_model
143144
self.timeout = timeout
144145
self.headers = headers or {}
145146
self.mute_progress_bars = mute_progress_bars
146-
self.db_version = BaseRester._get_database_version(self.endpoint)
147-
148-
if self.suffix:
149-
self.endpoint = urljoin(self.endpoint, self.suffix)
150-
if not self.endpoint.endswith("/"):
151-
self.endpoint += "/"
147+
self.db_version = BaseRester._get_database_version(self.base_endpoint)
152148

153149
self._session = session
154150
self._s3_client = s3_client
@@ -196,15 +192,14 @@ def _create_session(api_key, include_user_agent, headers):
196192
user_agent = f"{mp_api_info} ({python_info} {platform_info})"
197193
session.headers["user-agent"] = user_agent
198194

199-
settings = MAPIClientSettings() # type: ignore
200-
max_retry_num = settings.MAX_RETRIES
195+
max_retry_num = MAPI_CLIENT_SETTINGS.MAX_RETRIES
201196
retry = Retry(
202197
total=max_retry_num,
203198
read=max_retry_num,
204199
connect=max_retry_num,
205200
respect_retry_after_header=True,
206201
status_forcelist=[429, 504, 502], # rate limiting
207-
backoff_factor=settings.BACKOFF_FACTOR,
202+
backoff_factor=MAPI_CLIENT_SETTINGS.BACKOFF_FACTOR,
208203
)
209204
adapter = HTTPAdapter(max_retries=retry)
210205
session.mount("http://", adapter)
@@ -265,11 +260,7 @@ def _post_resource(
265260
payload = jsanitize(body)
266261

267262
try:
268-
url = self.endpoint
269-
if suburl:
270-
url = urljoin(self.endpoint, suburl)
271-
if not url.endswith("/"):
272-
url += "/"
263+
url = validate_endpoint(self.endpoint, suffix=suburl)
273264
response = self.session.post(url, json=payload, verify=True, params=params)
274265

275266
if response.status_code == 200:
@@ -333,11 +324,7 @@ def _patch_resource(
333324
payload = jsanitize(body)
334325

335326
try:
336-
url = self.endpoint
337-
if suburl:
338-
url = urljoin(self.endpoint, suburl)
339-
if not url.endswith("/"):
340-
url += "/"
327+
url = validate_endpoint(self.endpoint, suffix=suburl)
341328
response = self.session.patch(url, json=payload, verify=True, params=params)
342329

343330
if response.status_code == 200:
@@ -469,11 +456,7 @@ def _query_resource(
469456
criteria["_fields"] = ",".join(fields)
470457

471458
try:
472-
url = self.endpoint
473-
if suburl:
474-
url = urljoin(self.endpoint, suburl)
475-
if not url.endswith("/"):
476-
url += "/"
459+
url = validate_endpoint(self.endpoint, suffix=suburl)
477460

478461
if query_s3:
479462
db_version = self.db_version.replace(".", "-")
@@ -620,15 +603,15 @@ def _submit_requests( # noqa
620603

621604
bare_url_len = len(url_string)
622605
max_param_str_length = (
623-
MAPIClientSettings().MAX_HTTP_URL_LENGTH - bare_url_len # type: ignore
606+
MAPI_CLIENT_SETTINGS.MAX_HTTP_URL_LENGTH - bare_url_len # type: ignore
624607
)
625608

626609
# Next, check if default number of parallel requests works.
627610
# If not, make slice size the minimum number of param entries
628611
# contained in any substring of length max_param_str_length.
629612
param_length = len(criteria[parallel_param].split(","))
630613
slice_size = (
631-
int(param_length / MAPIClientSettings().NUM_PARALLEL_REQUESTS) or 1 # type: ignore
614+
int(param_length / MAPI_CLIENT_SETTINGS.NUM_PARALLEL_REQUESTS) or 1 # type: ignore
632615
)
633616

634617
url_param_string = quote(criteria[parallel_param])
@@ -909,14 +892,14 @@ def _multi_thread(
909892
params_ind = 0
910893

911894
with ThreadPoolExecutor(
912-
max_workers=MAPIClientSettings().NUM_PARALLEL_REQUESTS # type: ignore
895+
max_workers=MAPI_CLIENT_SETTINGS.NUM_PARALLEL_REQUESTS # type: ignore
913896
) as executor:
914897
# Get list of initial futures defined by max number of parallel requests
915898
futures = set()
916899

917900
for params in itertools.islice(
918901
params_gen,
919-
MAPIClientSettings().NUM_PARALLEL_REQUESTS, # type: ignore
902+
MAPI_CLIENT_SETTINGS.NUM_PARALLEL_REQUESTS, # type: ignore
920903
):
921904
future = executor.submit(
922905
func,
@@ -1278,7 +1261,7 @@ def _get_all_documents(
12781261
for key, entry in query_params.items()
12791262
if isinstance(entry, str)
12801263
and len(entry.split(",")) > 0
1281-
and key not in MAPIClientSettings().QUERY_NO_PARALLEL # type: ignore
1264+
and key not in MAPI_CLIENT_SETTINGS.QUERY_NO_PARALLEL # type: ignore
12821265
),
12831266
key=lambda item: item[1],
12841267
reverse=True,
@@ -1369,10 +1352,9 @@ class CoreRester(BaseRester):
13691352
def __getattr__(self, v: str):
13701353
if v in self._sub_resters:
13711354
if self._sub_resters[v]._obj is None:
1372-
13731355
self._sub_resters[v](
13741356
api_key=self.api_key,
1375-
endpoint=self.endpoint.split(self.suffix)[0],
1357+
endpoint=self.base_endpoint,
13761358
include_user_agent=self._include_user_agent,
13771359
session=self.session,
13781360
use_document_model=self.use_document_model,

mp_api/client/core/settings.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,3 +99,6 @@ class MAPIClientSettings(BaseSettings):
9999
def _get_endpoint_from_env(cls, v: str | None) -> str:
100100
"""Support setting endpoint via MP_API_ENDPOINT environment variable."""
101101
return v or os.environ.get("MP_API_ENDPOINT") or _DEFAULT_ENDPOINT
102+
103+
104+
MAPI_CLIENT_SETTINGS = MAPIClientSettings()

mp_api/client/core/utils.py

Lines changed: 31 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import os
44
from importlib import import_module
55
from typing import TYPE_CHECKING, Literal
6+
from urllib.parse import urljoin
67

78
import orjson
89
from emmet.core import __version__ as _EMMET_CORE_VER
@@ -11,7 +12,7 @@
1112
from packaging.version import parse as parse_version
1213

1314
from mp_api.client.core.exceptions import MPRestError
14-
from mp_api.client.core.settings import MAPIClientSettings
15+
from mp_api.client.core.settings import MAPI_CLIENT_SETTINGS
1516

1617
if TYPE_CHECKING:
1718
from typing import Any
@@ -58,9 +59,9 @@ def validate_api_key(api_key: str | None = None) -> str:
5859
# SETTINGS tries to read API key from ~/.config/.pmgrc.yaml
5960
api_key = api_key or os.getenv("MP_API_KEY")
6061
if not api_key:
61-
from pymatgen.core import SETTINGS
62+
from pymatgen.core import SETTINGS as PMG_SETTINGS
6263

63-
api_key = SETTINGS.get("PMG_MAPI_KEY")
64+
api_key = PMG_SETTINGS.get("PMG_MAPI_KEY")
6465

6566
if not api_key or len(api_key) != 32:
6667
addendum = " Valid API keys are 32 characters." if api_key else ""
@@ -84,7 +85,7 @@ def validate_ids(id_list: list[str]) -> list[str]:
8485
Returns:
8586
id_list: Returns original ID list if everything is formatted correctly.
8687
"""
87-
if len(id_list) > MAPIClientSettings().MAX_LIST_LENGTH:
88+
if len(id_list) > MAPI_CLIENT_SETTINGS.MAX_LIST_LENGTH:
8889
raise MPRestError(
8990
"List of material/molecule IDs provided is too long. Consider removing the ID filter to automatically pull"
9091
" data for all IDs and filter locally."
@@ -96,6 +97,32 @@ def validate_ids(id_list: list[str]) -> list[str]:
9697
return [str(validate_identifier(idx)) for idx in id_list]
9798

9899

100+
def validate_endpoint(endpoint: str | None, suffix: str | None = None) -> str:
101+
"""Validate an endpoint with optional suffix.
102+
103+
NB: does not modify the endpoint in place,
104+
returns a new variable.
105+
106+
Parameters
107+
-----------
108+
endpoint : str or None (default)
109+
A string representing the endpoint URL or the default
110+
in `mp_api.client.core.settings`
111+
suffix : str or None (default)
112+
Optional suffix to append to the endpoint.
113+
114+
Returns:
115+
-----------
116+
str : the validated endpoint
117+
"""
118+
new_endpoint = endpoint or MAPI_CLIENT_SETTINGS.ENDPOINT
119+
if suffix:
120+
new_endpoint = urljoin(new_endpoint, suffix)
121+
if not new_endpoint.endswith("/"):
122+
new_endpoint += "/"
123+
return new_endpoint
124+
125+
99126
class LazyImport:
100127
"""Lazily import and load an object.
101128

mp_api/client/mprester.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,12 @@
2424

2525
from mp_api.client.core import BaseRester, MPRestError, MPRestWarning
2626
from mp_api.client.core._oxygen_evolution import OxygenEvolution
27-
from mp_api.client.core.settings import MAPIClientSettings
27+
from mp_api.client.core.settings import MAPI_CLIENT_SETTINGS
2828
from mp_api.client.core.utils import (
2929
LazyImport,
3030
load_json,
3131
validate_api_key,
32+
validate_endpoint,
3233
validate_ids,
3334
)
3435
from mp_api.client.routes import GENERIC_RESTERS
@@ -42,7 +43,6 @@
4243
from pymatgen.entries.computed_entries import ComputedEntry
4344

4445
_EMMET_SETTINGS = EmmetSettings()
45-
_MAPI_SETTINGS = MAPIClientSettings()
4646
DEFAULT_THERMOTYPE_CRITERIA = {"thermo_types": ["GGA_GGA+U"]}
4747

4848
RESTER_LAYOUT = {
@@ -80,7 +80,7 @@ def __init__(
8080
use_document_model: bool = True,
8181
session: Session | None = None,
8282
headers: dict | None = None,
83-
mute_progress_bars: bool = _MAPI_SETTINGS.MUTE_PROGRESS_BARS,
83+
mute_progress_bars: bool = MAPI_CLIENT_SETTINGS.MUTE_PROGRESS_BARS,
8484
**kwargs,
8585
):
8686
"""Initialize the MPRester.
@@ -119,7 +119,7 @@ def __init__(
119119
"""
120120
self.api_key = validate_api_key(api_key)
121121

122-
self.endpoint = endpoint or _MAPI_SETTINGS.ENDPOINT
122+
self.endpoint = validate_endpoint(endpoint) or MAPI_CLIENT_SETTINGS.ENDPOINT
123123
if not self.endpoint.endswith("/"):
124124
self.endpoint += "/"
125125

@@ -176,7 +176,7 @@ def __init__(
176176
emmet_version = MPRester.get_emmet_version(self.endpoint)
177177

178178
if version.parse(emmet_version.base_version) < version.parse(
179-
_MAPI_SETTINGS.MIN_EMMET_VERSION
179+
MAPI_CLIENT_SETTINGS.MIN_EMMET_VERSION
180180
):
181181
warnings.warn(
182182
"The installed version of the mp-api client may not be compatible with the API server. "

mp_api/client/routes/molecules/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
k: LazyImport(f"mp_api.client.routes.molecules.{k}.{v}")
77
for k, v in (
88
("molecules", "MoleculeRester"),
9-
("jcser", "JcesrMoleculesRester"),
9+
("jcesr", "JcesrMoleculesRester"),
1010
("summary", "MoleculesSummaryRester"),
1111
)
1212
}

mp_api/client/routes/molecules/jcesr.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import warnings
34
from collections import defaultdict
45

56
from emmet.core.molecules_jcesr import MoleculesDoc
@@ -14,6 +15,14 @@ class JcesrMoleculesRester(BaseRester):
1415
document_model = MoleculesDoc # type: ignore
1516
primary_key = "task_id"
1617

18+
def __init__(self, **kwargs):
19+
"""Throw deprecation warning when JCESR client is initialized."""
20+
warnings.warn(
21+
"NOTE: You are accessing the unmaintained legacy molecules data, "
22+
"please use MPRester.molecules.summary."
23+
)
24+
super().__init__(**kwargs)
25+
1726
def search(
1827
self,
1928
task_ids: str | list[str] | None = None,

mp_api/client/routes/molecules/molecules.py

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
from __future__ import annotations
22

3-
import warnings
4-
53
from emmet.core.mpid import MPculeID
64
from emmet.core.qchem.molecule import MoleculeDoc
75
from emmet.core.settings import EmmetSettings
@@ -196,13 +194,5 @@ class MoleculeRester(BaseMoleculeRester):
196194
suffix = "molecules/core"
197195
_sub_resters = MOLECULES_RESTERS
198196

199-
def __getattr__(self, v: str):
200-
if "jcesr" in v:
201-
warnings.warn(
202-
"NOTE: You are accessing the unmaintained legacy molecules data, "
203-
"please use MPRester.molecules.summary."
204-
)
205-
super().__getattr__(v)
206-
207197
def __dir__(self):
208198
return dir(MoleculeRester) + self._sub_resters

tests/core/test_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -85,9 +85,9 @@ def test_id_validation():
8585
from emmet.core.mpid import MPID, AlphaID
8686

8787
from mp_api.client.core.utils import validate_ids
88-
from mp_api.client.core.settings import MAPIClientSettings
88+
from mp_api.client.core.settings import MAPI_CLIENT_SETTINGS
8989

90-
max_num_idxs = MAPIClientSettings().MAX_LIST_LENGTH
90+
max_num_idxs = MAPI_CLIENT_SETTINGS.MAX_LIST_LENGTH
9191

9292
with pytest.raises(MPRestError, match="too long"):
9393
_ = validate_ids([f"mp-{x}" for x in range(max_num_idxs + 1)])

tests/molecules/test_jcesr.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,3 +48,8 @@ def test_client(rester):
4848
custom_field_tests=custom_field_tests,
4949
sub_doc_fields=sub_doc_fields,
5050
)
51+
52+
53+
def test_warning():
54+
with pytest.warns(UserWarning, match="unmaintained legacy molecules"):
55+
JcesrMoleculesRester()

0 commit comments

Comments
 (0)