|
19 | 19 | from json import JSONDecodeError |
20 | 20 | from math import ceil |
21 | 21 | from typing import TYPE_CHECKING, ForwardRef, Optional, get_args |
22 | | -from urllib.parse import quote, urljoin |
| 22 | +from urllib.parse import quote |
23 | 23 |
|
24 | 24 | import requests |
25 | 25 | from emmet.core.utils import jsanitize |
|
31 | 31 | from urllib3.util.retry import Retry |
32 | 32 |
|
33 | 33 | from mp_api.client.core.exceptions import MPRestError |
34 | | -from mp_api.client.core.settings import MAPIClientSettings |
35 | | -from mp_api.client.core.utils import load_json, validate_ids |
| 34 | +from mp_api.client.core.settings import MAPI_CLIENT_SETTINGS |
| 35 | +from mp_api.client.core.utils import ( |
| 36 | + load_json, |
| 37 | + validate_api_key, |
| 38 | + validate_endpoint, |
| 39 | + validate_ids, |
| 40 | +) |
36 | 41 |
|
37 | 42 | try: |
38 | 43 | import boto3 |
|
59 | 64 | __version__ = os.getenv("SETUPTOOLS_SCM_PRETEND_VERSION") |
60 | 65 |
|
61 | 66 |
|
62 | | -SETTINGS = MAPIClientSettings() # type: ignore |
63 | | - |
64 | | - |
65 | 67 | class _DictLikeAccess(BaseModel): |
66 | 68 | """Define a pydantic mix-in which permits dict-like access to model fields.""" |
67 | 69 |
|
@@ -98,7 +100,7 @@ def __init__( |
98 | 100 | use_document_model: bool = True, |
99 | 101 | timeout: int = 20, |
100 | 102 | headers: dict | None = None, |
101 | | - mute_progress_bars: bool = SETTINGS.MUTE_PROGRESS_BARS, |
| 103 | + mute_progress_bars: bool = MAPI_CLIENT_SETTINGS.MUTE_PROGRESS_BARS, |
102 | 104 | **kwargs, |
103 | 105 | ): |
104 | 106 | """Initialize the REST API helper class. |
@@ -132,23 +134,17 @@ def __init__( |
132 | 134 | mute_progress_bars: Whether to disable progress bars. |
133 | 135 | **kwargs: access to legacy kwargs that may be in the process of being deprecated |
134 | 136 | """ |
135 | | - # TODO: think about how to migrate from PMG_MAPI_KEY |
136 | | - self.api_key = api_key or os.getenv("MP_API_KEY") |
137 | | - self.base_endpoint = self.endpoint = endpoint or os.getenv( |
138 | | - "MP_API_ENDPOINT", "https://api.materialsproject.org/" |
139 | | - ) |
| 137 | + self.api_key = validate_api_key(api_key) |
| 138 | + self.base_endpoint = validate_endpoint(endpoint) |
| 139 | + self.endpoint = validate_endpoint(endpoint, suffix=self.suffix) |
| 140 | + |
140 | 141 | self.debug = debug |
141 | 142 | self.include_user_agent = include_user_agent |
142 | 143 | self.use_document_model = use_document_model |
143 | 144 | self.timeout = timeout |
144 | 145 | self.headers = headers or {} |
145 | 146 | self.mute_progress_bars = mute_progress_bars |
146 | | - self.db_version = BaseRester._get_database_version(self.endpoint) |
147 | | - |
148 | | - if self.suffix: |
149 | | - self.endpoint = urljoin(self.endpoint, self.suffix) |
150 | | - if not self.endpoint.endswith("/"): |
151 | | - self.endpoint += "/" |
| 147 | + self.db_version = BaseRester._get_database_version(self.base_endpoint) |
152 | 148 |
|
153 | 149 | self._session = session |
154 | 150 | self._s3_client = s3_client |
@@ -196,15 +192,14 @@ def _create_session(api_key, include_user_agent, headers): |
196 | 192 | user_agent = f"{mp_api_info} ({python_info} {platform_info})" |
197 | 193 | session.headers["user-agent"] = user_agent |
198 | 194 |
|
199 | | - settings = MAPIClientSettings() # type: ignore |
200 | | - max_retry_num = settings.MAX_RETRIES |
| 195 | + max_retry_num = MAPI_CLIENT_SETTINGS.MAX_RETRIES |
201 | 196 | retry = Retry( |
202 | 197 | total=max_retry_num, |
203 | 198 | read=max_retry_num, |
204 | 199 | connect=max_retry_num, |
205 | 200 | respect_retry_after_header=True, |
206 | 201 | status_forcelist=[429, 504, 502], # rate limiting |
207 | | - backoff_factor=settings.BACKOFF_FACTOR, |
| 202 | + backoff_factor=MAPI_CLIENT_SETTINGS.BACKOFF_FACTOR, |
208 | 203 | ) |
209 | 204 | adapter = HTTPAdapter(max_retries=retry) |
210 | 205 | session.mount("http://", adapter) |
@@ -265,11 +260,7 @@ def _post_resource( |
265 | 260 | payload = jsanitize(body) |
266 | 261 |
|
267 | 262 | try: |
268 | | - url = self.endpoint |
269 | | - if suburl: |
270 | | - url = urljoin(self.endpoint, suburl) |
271 | | - if not url.endswith("/"): |
272 | | - url += "/" |
| 263 | + url = validate_endpoint(self.endpoint, suffix=suburl) |
273 | 264 | response = self.session.post(url, json=payload, verify=True, params=params) |
274 | 265 |
|
275 | 266 | if response.status_code == 200: |
@@ -333,11 +324,7 @@ def _patch_resource( |
333 | 324 | payload = jsanitize(body) |
334 | 325 |
|
335 | 326 | try: |
336 | | - url = self.endpoint |
337 | | - if suburl: |
338 | | - url = urljoin(self.endpoint, suburl) |
339 | | - if not url.endswith("/"): |
340 | | - url += "/" |
| 327 | + url = validate_endpoint(self.endpoint, suffix=suburl) |
341 | 328 | response = self.session.patch(url, json=payload, verify=True, params=params) |
342 | 329 |
|
343 | 330 | if response.status_code == 200: |
@@ -469,11 +456,7 @@ def _query_resource( |
469 | 456 | criteria["_fields"] = ",".join(fields) |
470 | 457 |
|
471 | 458 | try: |
472 | | - url = self.endpoint |
473 | | - if suburl: |
474 | | - url = urljoin(self.endpoint, suburl) |
475 | | - if not url.endswith("/"): |
476 | | - url += "/" |
| 459 | + url = validate_endpoint(self.endpoint, suffix=suburl) |
477 | 460 |
|
478 | 461 | if query_s3: |
479 | 462 | db_version = self.db_version.replace(".", "-") |
@@ -620,15 +603,15 @@ def _submit_requests( # noqa |
620 | 603 |
|
621 | 604 | bare_url_len = len(url_string) |
622 | 605 | max_param_str_length = ( |
623 | | - MAPIClientSettings().MAX_HTTP_URL_LENGTH - bare_url_len # type: ignore |
| 606 | + MAPI_CLIENT_SETTINGS.MAX_HTTP_URL_LENGTH - bare_url_len # type: ignore |
624 | 607 | ) |
625 | 608 |
|
626 | 609 | # Next, check if default number of parallel requests works. |
627 | 610 | # If not, make slice size the minimum number of param entries |
628 | 611 | # contained in any substring of length max_param_str_length. |
629 | 612 | param_length = len(criteria[parallel_param].split(",")) |
630 | 613 | slice_size = ( |
631 | | - int(param_length / MAPIClientSettings().NUM_PARALLEL_REQUESTS) or 1 # type: ignore |
| 614 | + int(param_length / MAPI_CLIENT_SETTINGS.NUM_PARALLEL_REQUESTS) or 1 # type: ignore |
632 | 615 | ) |
633 | 616 |
|
634 | 617 | url_param_string = quote(criteria[parallel_param]) |
@@ -909,14 +892,14 @@ def _multi_thread( |
909 | 892 | params_ind = 0 |
910 | 893 |
|
911 | 894 | with ThreadPoolExecutor( |
912 | | - max_workers=MAPIClientSettings().NUM_PARALLEL_REQUESTS # type: ignore |
| 895 | + max_workers=MAPI_CLIENT_SETTINGS.NUM_PARALLEL_REQUESTS # type: ignore |
913 | 896 | ) as executor: |
914 | 897 | # Get list of initial futures defined by max number of parallel requests |
915 | 898 | futures = set() |
916 | 899 |
|
917 | 900 | for params in itertools.islice( |
918 | 901 | params_gen, |
919 | | - MAPIClientSettings().NUM_PARALLEL_REQUESTS, # type: ignore |
| 902 | + MAPI_CLIENT_SETTINGS.NUM_PARALLEL_REQUESTS, # type: ignore |
920 | 903 | ): |
921 | 904 | future = executor.submit( |
922 | 905 | func, |
@@ -1278,7 +1261,7 @@ def _get_all_documents( |
1278 | 1261 | for key, entry in query_params.items() |
1279 | 1262 | if isinstance(entry, str) |
1280 | 1263 | and len(entry.split(",")) > 0 |
1281 | | - and key not in MAPIClientSettings().QUERY_NO_PARALLEL # type: ignore |
| 1264 | + and key not in MAPI_CLIENT_SETTINGS.QUERY_NO_PARALLEL # type: ignore |
1282 | 1265 | ), |
1283 | 1266 | key=lambda item: item[1], |
1284 | 1267 | reverse=True, |
@@ -1369,10 +1352,9 @@ class CoreRester(BaseRester): |
1369 | 1352 | def __getattr__(self, v: str): |
1370 | 1353 | if v in self._sub_resters: |
1371 | 1354 | if self._sub_resters[v]._obj is None: |
1372 | | - |
1373 | 1355 | self._sub_resters[v]( |
1374 | 1356 | api_key=self.api_key, |
1375 | | - endpoint=self.endpoint.split(self.suffix)[0], |
| 1357 | + endpoint=self.base_endpoint, |
1376 | 1358 | include_user_agent=self._include_user_agent, |
1377 | 1359 | session=self.session, |
1378 | 1360 | use_document_model=self.use_document_model, |
|
0 commit comments