Skip to content

Commit 4217daf

Browse files
resolve merge conf
2 parents 7dcf96b + 30adeba commit 4217daf

23 files changed

Lines changed: 132 additions & 94 deletions

mp_api/client/core/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from __future__ import annotations
22

3-
from .client import BaseRester, MPRestError, MPRestWarning
3+
from .client import BaseRester
4+
from .exceptions import MPRestError, MPRestWarning
45
from .settings import MAPIClientSettings

mp_api/client/core/client.py

Lines changed: 14 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from tqdm.auto import tqdm
3131
from urllib3.util.retry import Retry
3232

33+
from mp_api.client.core.exceptions import MPRestError
3334
from mp_api.client.core.settings import MAPIClientSettings
3435
from mp_api.client.core.utils import load_json, validate_ids
3536

@@ -92,11 +93,11 @@ def __init__(
9293
session: requests.Session | None = None,
9394
s3_client: Any | None = None,
9495
debug: bool = False,
95-
monty_decode: bool = True,
9696
use_document_model: bool = True,
9797
timeout: int = 20,
9898
headers: dict | None = None,
9999
mute_progress_bars: bool = SETTINGS.MUTE_PROGRESS_BARS,
100+
**kwargs,
100101
):
101102
"""Initialize the REST API helper class.
102103
@@ -121,13 +122,13 @@ def __init__(
121122
advanced usage only.
122123
s3_client: boto3 S3 client object with which to connect to the object stores.ct to the object stores.ct to the object stores.
123124
debug: if True, print the URL for every request
124-
monty_decode: Decode the data using monty into python objects
125125
use_document_model: If False, skip the creating the document model and return data
126126
as a dictionary. This can be simpler to work with but bypasses data validation
127127
and will not give auto-complete for available fields.
128128
timeout: Time in seconds to wait until a request timeout error is thrown
129129
headers: Custom headers for localhost connections.
130130
mute_progress_bars: Whether to disable progress bars.
131+
**kwargs: access to legacy kwargs that may be in the process of being deprecated
131132
"""
132133
# TODO: think about how to migrate from PMG_MAPI_KEY
133134
self.api_key = api_key or os.getenv("MP_API_KEY")
@@ -136,7 +137,6 @@ def __init__(
136137
)
137138
self.debug = debug
138139
self.include_user_agent = include_user_agent
139-
self.monty_decode = monty_decode
140140
self.use_document_model = use_document_model
141141
self.timeout = timeout
142142
self.headers = headers or {}
@@ -151,6 +151,12 @@ def __init__(
151151
self._session = session
152152
self._s3_client = s3_client
153153

154+
if "monty_decode" in kwargs:
155+
warnings.warn(
156+
"Ignoring `monty_decode`, as it is no longer a supported option in `mp_api`."
157+
"The client by default returns results consistent with `monty_decode=True`."
158+
)
159+
154160
@property
155161
def session(self) -> requests.Session:
156162
if not self._session:
@@ -265,7 +271,7 @@ def _post_resource(
265271
response = self.session.post(url, json=payload, verify=True, params=params)
266272

267273
if response.status_code == 200:
268-
data = load_json(response.text, deser=self.monty_decode)
274+
data = load_json(response.text)
269275
if self.document_model and use_document_model:
270276
if isinstance(data["data"], dict):
271277
data["data"] = self.document_model.model_validate(data["data"]) # type: ignore
@@ -333,7 +339,7 @@ def _patch_resource(
333339
response = self.session.patch(url, json=payload, verify=True, params=params)
334340

335341
if response.status_code == 200:
336-
data = load_json(response.text, deser=self.monty_decode)
342+
data = load_json(response.text)
337343
if self.document_model and use_document_model:
338344
if isinstance(data["data"], dict):
339345
data["data"] = self.document_model.model_validate(data["data"]) # type: ignore
@@ -384,10 +390,7 @@ def _query_open_data(
384390
Returns:
385391
dict: MontyDecoded data
386392
"""
387-
if not decoder:
388-
389-
def decoder(x):
390-
return load_json(x, deser=self.monty_decode)
393+
decoder = decoder or load_json
391394

392395
file = open(
393396
f"s3://{bucket}/{key}",
@@ -997,7 +1000,7 @@ def _submit_request_and_process(
9971000
)
9981001

9991002
if response.status_code == 200:
1000-
data = load_json(response.text, deser=self.monty_decode)
1003+
data = load_json(response.text)
10011004
# other sub-urls may use different document models
10021005
# the client does not handle this in a particularly smart way currently
10031006
if self.document_model and use_document_model:
@@ -1302,12 +1305,10 @@ def count(self, criteria: dict | None = None) -> int | str:
13021305
"""
13031306
criteria = criteria or {}
13041307
user_preferences = (
1305-
self.monty_decode,
13061308
self.use_document_model,
13071309
self.mute_progress_bars,
13081310
)
1309-
self.monty_decode, self.use_document_model, self.mute_progress_bars = (
1310-
False,
1311+
self.use_document_model, self.mute_progress_bars = (
13111312
False,
13121313
True,
13131314
) # do not waste cycles decoding
@@ -1329,7 +1330,6 @@ def count(self, criteria: dict | None = None) -> int | str:
13291330
)
13301331

13311332
(
1332-
self.monty_decode,
13331333
self.use_document_model,
13341334
self.mute_progress_bars,
13351335
) = user_preferences
@@ -1351,11 +1351,3 @@ def __str__(self): # pragma: no cover
13511351
f"{self.__class__.__name__} connected to {self.endpoint}\n\n"
13521352
f"Available fields: {', '.join(self.available_fields)}\n\n"
13531353
)
1354-
1355-
1356-
class MPRestError(Exception):
1357-
"""Raised when the query has problems, e.g., bad query format."""
1358-
1359-
1360-
class MPRestWarning(Warning):
1361-
"""Raised when a query is malformed but interpretable."""

mp_api/client/core/exceptions.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
"""Define custom exceptions and warnings for the client."""
2+
from __future__ import annotations
3+
4+
5+
class MPRestError(Exception):
6+
"""Raised when the query has problems, e.g., bad query format."""
7+
8+
9+
class MPRestWarning(Warning):
10+
"""Raised when a query is malformed but interpretable."""

mp_api/client/core/settings.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from multiprocessing import cpu_count
33
from typing import List
44

5-
from pydantic import Field
5+
from pydantic import Field, field_validator
66
from pydantic_settings import BaseSettings, SettingsConfigDict
77
from pymatgen.core import _load_pmg_settings
88

@@ -14,6 +14,7 @@
1414
_MUTE_PROGRESS_BAR = PMG_SETTINGS.get("MPRESTER_MUTE_PROGRESS_BARS", False)
1515
_MAX_HTTP_URL_LENGTH = PMG_SETTINGS.get("MPRESTER_MAX_HTTP_URL_LENGTH", 2000)
1616
_MAX_LIST_LENGTH = min(PMG_SETTINGS.get("MPRESTER_MAX_LIST_LENGTH", 10000), 10000)
17+
_DEFAULT_ENDPOINT = "https://api.materialsproject.org/"
1718

1819
try:
1920
CPU_COUNT = cpu_count()
@@ -80,11 +81,21 @@ class MAPIClientSettings(BaseSettings):
8081
)
8182

8283
MIN_EMMET_VERSION: str = Field(
83-
"0.54.0", description="Minimum compatible version of emmet-core for the client."
84+
"0.86.3rc0",
85+
description="Minimum compatible version of emmet-core for the client.",
8486
)
8587

8688
MAX_LIST_LENGTH: int = Field(
8789
_MAX_LIST_LENGTH, description="Maximum length of query parameter list"
8890
)
8991

92+
ENDPOINT: str = Field(
93+
_DEFAULT_ENDPOINT, description="The default API endpoint to use."
94+
)
95+
9096
model_config = SettingsConfigDict(env_prefix="MPRESTER_")
97+
98+
@field_validator("ENDPOINT", mode="before")
99+
def _get_endpoint_from_env(cls, v: str | None) -> str:
100+
"""Support setting endpoint via MP_API_ENDPOINT environment variable."""
101+
return v or os.environ.get("MP_API_ENDPOINT") or _DEFAULT_ENDPOINT

mp_api/client/core/utils.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from monty.json import MontyDecoder
1010
from packaging.version import parse as parse_version
1111

12+
from mp_api.client.core.exceptions import MPRestError
1213
from mp_api.client.core.settings import MAPIClientSettings
1314

1415
if TYPE_CHECKING:
@@ -54,7 +55,7 @@ def load_json(
5455

5556

5657
def validate_api_key(api_key: str | None = None) -> str:
57-
"""Utility to find and pre-check validity of an API key."""
58+
"""Find and validate an API key."""
5859
# SETTINGS tries to read API key from ~/.config/.pmgrc.yaml
5960
api_key = api_key or os.getenv("MP_API_KEY")
6061
if not api_key:
@@ -64,7 +65,7 @@ def validate_api_key(api_key: str | None = None) -> str:
6465

6566
if not api_key or len(api_key) != 32:
6667
addendum = " Valid API keys are 32 characters." if api_key else ""
67-
raise ValueError(
68+
raise MPRestError(
6869
"Please obtain a valid API key from https://materialsproject.org/api "
6970
f"and export it as an environment variable `MP_API_KEY`.{addendum}"
7071
)
@@ -79,13 +80,13 @@ def validate_ids(id_list: list[str]) -> list[str]:
7980
id_list (List[str]): List of material or task IDs.
8081
8182
Raises:
82-
ValueError: If at least one ID is not formatted correctly.
83+
MPRestError: If at least one ID is not formatted correctly.
8384
8485
Returns:
8586
id_list: Returns original ID list if everything is formatted correctly.
8687
"""
87-
if len(id_list) > _MAPI_SETTINGS.MAX_LIST_LENGTH:
88-
raise ValueError(
88+
if len(id_list) > MAPIClientSettings().MAX_LIST_LENGTH:
89+
raise MPRestError(
8990
"List of material/molecule IDs provided is too long. Consider removing the ID filter to automatically pull"
9091
" data for all IDs and filter locally."
9192
)

0 commit comments

Comments
 (0)