Skip to content

Commit 30adeba

Browse files
Remove monty decoding (#1047)
2 parents 2478637 + d40cb19 commit 30adeba

23 files changed

Lines changed: 150 additions & 102 deletions

mp_api/client/core/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from __future__ import annotations
22

3-
from .client import BaseRester, MPRestError, MPRestWarning
3+
from .client import BaseRester
4+
from .exceptions import MPRestError, MPRestWarning
45
from .settings import MAPIClientSettings

mp_api/client/core/client.py

Lines changed: 14 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from tqdm.auto import tqdm
3131
from urllib3.util.retry import Retry
3232

33+
from mp_api.client.core.exceptions import MPRestError
3334
from mp_api.client.core.settings import MAPIClientSettings
3435
from mp_api.client.core.utils import load_json, validate_ids
3536

@@ -92,11 +93,11 @@ def __init__(
9293
session: requests.Session | None = None,
9394
s3_client: Any | None = None,
9495
debug: bool = False,
95-
monty_decode: bool = True,
9696
use_document_model: bool = True,
9797
timeout: int = 20,
9898
headers: dict | None = None,
9999
mute_progress_bars: bool = SETTINGS.MUTE_PROGRESS_BARS,
100+
**kwargs,
100101
):
101102
"""Initialize the REST API helper class.
102103
@@ -121,13 +122,13 @@ def __init__(
121122
advanced usage only.
122123
s3_client: boto3 S3 client object with which to connect to the object stores.
123124
debug: if True, print the URL for every request
124-
monty_decode: Decode the data using monty into python objects
125125
use_document_model: If False, skip creating the document model and return data
126126
as a dictionary. This can be simpler to work with but bypasses data validation
127127
and will not give auto-complete for available fields.
128128
timeout: Time in seconds to wait until a request timeout error is thrown
129129
headers: Custom headers for localhost connections.
130130
mute_progress_bars: Whether to disable progress bars.
131+
**kwargs: access to legacy kwargs that may be in the process of being deprecated
131132
"""
132133
# TODO: think about how to migrate from PMG_MAPI_KEY
133134
self.api_key = api_key or os.getenv("MP_API_KEY")
@@ -136,7 +137,6 @@ def __init__(
136137
)
137138
self.debug = debug
138139
self.include_user_agent = include_user_agent
139-
self.monty_decode = monty_decode
140140
self.use_document_model = use_document_model
141141
self.timeout = timeout
142142
self.headers = headers or {}
@@ -151,6 +151,12 @@ def __init__(
151151
self._session = session
152152
self._s3_client = s3_client
153153

154+
if "monty_decode" in kwargs:
155+
warnings.warn(
156+
"Ignoring `monty_decode`, as it is no longer a supported option in `mp_api`."
157+
"The client by default returns results consistent with `monty_decode=True`."
158+
)
159+
154160
@property
155161
def session(self) -> requests.Session:
156162
if not self._session:
@@ -265,7 +271,7 @@ def _post_resource(
265271
response = self.session.post(url, json=payload, verify=True, params=params)
266272

267273
if response.status_code == 200:
268-
data = load_json(response.text, deser=self.monty_decode)
274+
data = load_json(response.text)
269275
if self.document_model and use_document_model:
270276
if isinstance(data["data"], dict):
271277
data["data"] = self.document_model.model_validate(data["data"]) # type: ignore
@@ -333,7 +339,7 @@ def _patch_resource(
333339
response = self.session.patch(url, json=payload, verify=True, params=params)
334340

335341
if response.status_code == 200:
336-
data = load_json(response.text, deser=self.monty_decode)
342+
data = load_json(response.text)
337343
if self.document_model and use_document_model:
338344
if isinstance(data["data"], dict):
339345
data["data"] = self.document_model.model_validate(data["data"]) # type: ignore
@@ -384,10 +390,7 @@ def _query_open_data(
384390
Returns:
385391
dict: MontyDecoded data
386392
"""
387-
if not decoder:
388-
389-
def decoder(x):
390-
return load_json(x, deser=self.monty_decode)
393+
decoder = decoder or load_json
391394

392395
file = open(
393396
f"s3://{bucket}/{key}",
@@ -997,7 +1000,7 @@ def _submit_request_and_process(
9971000
)
9981001

9991002
if response.status_code == 200:
1000-
data = load_json(response.text, deser=self.monty_decode)
1003+
data = load_json(response.text)
10011004
# other sub-urls may use different document models
10021005
# the client does not handle this in a particularly smart way currently
10031006
if self.document_model and use_document_model:
@@ -1302,12 +1305,10 @@ def count(self, criteria: dict | None = None) -> int | str:
13021305
"""
13031306
criteria = criteria or {}
13041307
user_preferences = (
1305-
self.monty_decode,
13061308
self.use_document_model,
13071309
self.mute_progress_bars,
13081310
)
1309-
self.monty_decode, self.use_document_model, self.mute_progress_bars = (
1310-
False,
1311+
self.use_document_model, self.mute_progress_bars = (
13111312
False,
13121313
True,
13131314
) # do not waste cycles decoding
@@ -1329,7 +1330,6 @@ def count(self, criteria: dict | None = None) -> int | str:
13291330
)
13301331

13311332
(
1332-
self.monty_decode,
13331333
self.use_document_model,
13341334
self.mute_progress_bars,
13351335
) = user_preferences
@@ -1351,11 +1351,3 @@ def __str__(self): # pragma: no cover
13511351
f"{self.__class__.__name__} connected to {self.endpoint}\n\n"
13521352
f"Available fields: {', '.join(self.available_fields)}\n\n"
13531353
)
1354-
1355-
1356-
class MPRestError(Exception):
1357-
"""Raised when the query has problems, e.g., bad query format."""
1358-
1359-
1360-
class MPRestWarning(Warning):
1361-
"""Raised when a query is malformed but interpretable."""

mp_api/client/core/exceptions.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
"""Define custom exceptions and warnings for the client."""
2+
from __future__ import annotations
3+
4+
5+
class MPRestError(Exception):
6+
"""Raised when the query has problems, e.g., bad query format."""
7+
8+
9+
class MPRestWarning(Warning):
10+
"""Raised when a query is malformed but interpretable."""

mp_api/client/core/settings.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from multiprocessing import cpu_count
33
from typing import List
44

5-
from pydantic import Field
5+
from pydantic import Field, field_validator
66
from pydantic_settings import BaseSettings, SettingsConfigDict
77
from pymatgen.core import _load_pmg_settings
88

@@ -14,6 +14,7 @@
1414
_MUTE_PROGRESS_BAR = PMG_SETTINGS.get("MPRESTER_MUTE_PROGRESS_BARS", False)
1515
_MAX_HTTP_URL_LENGTH = PMG_SETTINGS.get("MPRESTER_MAX_HTTP_URL_LENGTH", 2000)
1616
_MAX_LIST_LENGTH = min(PMG_SETTINGS.get("MPRESTER_MAX_LIST_LENGTH", 10000), 10000)
17+
_DEFAULT_ENDPOINT = "https://api.materialsproject.org/"
1718

1819
try:
1920
CPU_COUNT = cpu_count()
@@ -80,11 +81,21 @@ class MAPIClientSettings(BaseSettings):
8081
)
8182

8283
MIN_EMMET_VERSION: str = Field(
83-
"0.54.0", description="Minimum compatible version of emmet-core for the client."
84+
"0.86.3rc0",
85+
description="Minimum compatible version of emmet-core for the client.",
8486
)
8587

8688
MAX_LIST_LENGTH: int = Field(
8789
_MAX_LIST_LENGTH, description="Maximum length of query parameter list"
8890
)
8991

92+
ENDPOINT: str = Field(
93+
_DEFAULT_ENDPOINT, description="The default API endpoint to use."
94+
)
95+
9096
model_config = SettingsConfigDict(env_prefix="MPRESTER_")
97+
98+
@field_validator("ENDPOINT", mode="before")
99+
def _get_endpoint_from_env(cls, v: str | None) -> str:
100+
"""Support setting endpoint via MP_API_ENDPOINT environment variable."""
101+
return v or os.environ.get("MP_API_ENDPOINT") or _DEFAULT_ENDPOINT

mp_api/client/core/utils.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import os
34
from typing import TYPE_CHECKING, Literal
45

56
import orjson
@@ -8,6 +9,7 @@
89
from monty.json import MontyDecoder
910
from packaging.version import parse as parse_version
1011

12+
from mp_api.client.core.exceptions import MPRestError
1113
from mp_api.client.core.settings import MAPIClientSettings
1214

1315
if TYPE_CHECKING:
@@ -50,20 +52,39 @@ def load_json(
5052
return MontyDecoder().process_decoded(data) if deser else data
5153

5254

55+
def validate_api_key(api_key: str | None = None) -> str:
56+
"""Find and validate an API key."""
57+
# SETTINGS tries to read API key from ~/.config/.pmgrc.yaml
58+
api_key = api_key or os.getenv("MP_API_KEY")
59+
if not api_key:
60+
from pymatgen.core import SETTINGS
61+
62+
api_key = SETTINGS.get("PMG_MAPI_KEY")
63+
64+
if not api_key or len(api_key) != 32:
65+
addendum = " Valid API keys are 32 characters." if api_key else ""
66+
raise MPRestError(
67+
"Please obtain a valid API key from https://materialsproject.org/api "
68+
f"and export it as an environment variable `MP_API_KEY`.{addendum}"
69+
)
70+
71+
return api_key
72+
73+
5374
def validate_ids(id_list: list[str]) -> list[str]:
5475
"""Function to validate material and task IDs.
5576
5677
Args:
5778
id_list (List[str]): List of material or task IDs.
5879
5980
Raises:
60-
ValueError: If at least one ID is not formatted correctly.
81+
MPRestError: If at least one ID is not formatted correctly.
6182
6283
Returns:
6384
id_list: Returns original ID list if everything is formatted correctly.
6485
"""
6586
if len(id_list) > MAPIClientSettings().MAX_LIST_LENGTH:
66-
raise ValueError(
87+
raise MPRestError(
6788
"List of material/molecule IDs provided is too long. Consider removing the ID filter to automatically pull"
6889
" data for all IDs and filter locally."
6990
)

0 commit comments

Comments
 (0)