Skip to content

Commit 3f88c3c

Browse files
merge conflicts
2 parents 6d5b688 + 84a3f1e commit 3f88c3c

25 files changed

Lines changed: 391 additions & 655 deletions

.gitignore

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,4 +120,4 @@ ENV/
120120
_autosummary
121121

122122
uv.lock
123-
JANAF_O2_data.json
123+
JANAF_*_data.json

LICENSE

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
maggma Copyright (c) 2017, The Regents of the University of
1+
Copyright (c) 2017, The Regents of the University of
22
California, through Lawrence Berkeley National Laboratory (subject
33
to receipt of any required approvals from the U.S. Dept. of Energy).
44
All rights reserved.

mp_api/client/core/client.py

Lines changed: 57 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77

88
import inspect
99
import itertools
10-
import json
1110
import os
1211
import platform
1312
import sys
@@ -23,15 +22,14 @@
2322
TYPE_CHECKING,
2423
ForwardRef,
2524
Generic,
25+
Optional,
2626
TypeVar,
2727
get_args,
2828
)
2929
from urllib.parse import quote, urljoin
3030

3131
import requests
32-
from bson import json_util
3332
from emmet.core.utils import jsanitize
34-
from monty.json import MontyDecoder
3533
from pydantic import BaseModel, create_model
3634
from requests.adapters import HTTPAdapter
3735
from requests.exceptions import RequestException
@@ -40,7 +38,7 @@
4038
from urllib3.util.retry import Retry
4139

4240
from mp_api.client.core.settings import MAPIClientSettings
43-
from mp_api.client.core.utils import api_sanitize, validate_ids
41+
from mp_api.client.core.utils import load_json, validate_ids
4442

4543
try:
4644
import boto3
@@ -57,6 +55,8 @@
5755
if TYPE_CHECKING:
5856
from typing import Any, Callable
5957

58+
from pydantic.fields import FieldInfo
59+
6060
try:
6161
__version__ = version("mp_api")
6262
except PackageNotFoundError: # pragma: no cover
@@ -150,12 +150,6 @@ def __init__(
150150
else:
151151
self._s3_client = None
152152

153-
self.document_model = (
154-
api_sanitize(self.document_model) # type: ignore
155-
if self.document_model is not None
156-
else None # type: ignore
157-
)
158-
159153
@property
160154
def session(self) -> requests.Session:
161155
if not self._session:
@@ -270,11 +264,7 @@ def _post_resource(
270264
response = self.session.post(url, json=payload, verify=True, params=params)
271265

272266
if response.status_code == 200:
273-
if self.monty_decode:
274-
data = json.loads(response.text, cls=MontyDecoder)
275-
else:
276-
data = json.loads(response.text)
277-
267+
data = load_json(response.text, deser=self.monty_decode)
278268
if self.document_model and use_document_model:
279269
if isinstance(data["data"], dict):
280270
data["data"] = self.document_model.model_validate(data["data"]) # type: ignore
@@ -287,7 +277,7 @@ def _post_resource(
287277

288278
else:
289279
try:
290-
data = json.loads(response.text)["detail"]
280+
data = load_json(response.text)["detail"]
291281
except (JSONDecodeError, KeyError):
292282
data = f"Response {response.text}"
293283
if isinstance(data, str):
@@ -342,11 +332,7 @@ def _patch_resource(
342332
response = self.session.patch(url, json=payload, verify=True, params=params)
343333

344334
if response.status_code == 200:
345-
if self.monty_decode:
346-
data = json.loads(response.text, cls=MontyDecoder)
347-
else:
348-
data = json.loads(response.text)
349-
335+
data = load_json(response.text, deser=self.monty_decode)
350336
if self.document_model and use_document_model:
351337
if isinstance(data["data"], dict):
352338
data["data"] = self.document_model.model_validate(data["data"]) # type: ignore
@@ -359,7 +345,7 @@ def _patch_resource(
359345

360346
else:
361347
try:
362-
data = json.loads(response.text)["detail"]
348+
data = load_json(response.text)["detail"]
363349
except (JSONDecodeError, KeyError):
364350
data = f"Response {response.text}"
365351
if isinstance(data, str):
@@ -384,18 +370,24 @@ def _query_open_data(
384370
self,
385371
bucket: str,
386372
key: str,
387-
decoder: Callable,
373+
decoder: Callable | None = None,
388374
) -> tuple[list[dict] | list[bytes], int]:
389375
"""Query and deserialize Materials Project AWS open data s3 buckets.
390376
391377
Args:
392378
bucket (str): Materials project bucket name
393379
key (str): Key for file including all prefixes
394-
decoder(Callable): Callable used to deserialize data
380+
decoder(Callable or None): Callable used to deserialize data.
381+
Defaults to mp_api.core.utils.load_json
395382
396383
Returns:
397384
dict: MontyDecoded data
398385
"""
386+
if not decoder:
387+
388+
def decoder(x):
389+
return load_json(x, deser=self.monty_decode)
390+
399391
file = open(
400392
f"s3://{bucket}/{key}",
401393
encoding="utf-8",
@@ -527,16 +519,11 @@ def _query_resource(
527519
"Ignoring `fields` argument: All fields are always included when no query is provided."
528520
)
529521

530-
decoder = (
531-
MontyDecoder().decode if self.monty_decode else json_util.loads
532-
)
533-
534522
# Multithreaded function inputs
535523
s3_params_list = {
536524
key: {
537525
"bucket": bucket,
538526
"key": key,
539-
"decoder": decoder,
540527
}
541528
for key in keys
542529
}
@@ -1013,11 +1000,7 @@ def _submit_request_and_process(
10131000
)
10141001

10151002
if response.status_code == 200:
1016-
if self.monty_decode:
1017-
data = json.loads(response.text, cls=MontyDecoder)
1018-
else:
1019-
data = json.loads(response.text)
1020-
1003+
data = load_json(response.text, deser=self.monty_decode)
10211004
# other sub-urls may use different document models
10221005
# the client does not handle this in a particularly smart way currently
10231006
if self.document_model and use_document_model:
@@ -1029,7 +1012,7 @@ def _submit_request_and_process(
10291012

10301013
else:
10311014
try:
1032-
data = json.loads(response.text)["detail"]
1015+
data = load_json(response.text)["detail"]
10331016
except (JSONDecodeError, KeyError):
10341017
data = f"Response {response.text}"
10351018
if isinstance(data, str):
@@ -1057,10 +1040,8 @@ def _convert_to_model(self, data: list[dict]):
10571040
(list[MPDataDoc]): List of MPDataDoc objects
10581041
10591042
"""
1060-
raw_doc_list = [self.document_model.model_validate(d) for d in data] # type: ignore
1061-
1062-
if len(raw_doc_list) > 0:
1063-
data_model, set_fields, _ = self._generate_returned_model(raw_doc_list[0])
1043+
if len(data) > 0:
1044+
data_model, set_fields, _ = self._generate_returned_model(data[0])
10641045

10651046
data = [
10661047
data_model(
@@ -1070,44 +1051,56 @@ def _convert_to_model(self, data: list[dict]):
10701051
if field in set_fields
10711052
}
10721053
)
1073-
for raw_doc in raw_doc_list
1054+
for raw_doc in data
10741055
]
10751056

10761057
return data
10771058

1078-
def _generate_returned_model(self, doc):
1059+
def _generate_returned_model(
1060+
self, doc: dict[str, Any]
1061+
) -> tuple[BaseModel, list[str], list[str]]:
10791062
model_fields = self.document_model.model_fields
1080-
1081-
set_fields = doc.model_fields_set
1063+
set_fields = [k for k in doc if k in model_fields]
10821064
unset_fields = [field for field in model_fields if field not in set_fields]
10831065

10841066
# Update with locals() from external module if needed
1085-
other_vars = {}
10861067
if any(
1068+
isinstance(field_meta.annotation, ForwardRef)
1069+
for field_meta in model_fields.values()
1070+
) or any(
10871071
isinstance(typ, ForwardRef)
10881072
for field_meta in model_fields.values()
10891073
for typ in get_args(field_meta.annotation)
10901074
):
1091-
other_vars = vars(import_module(self.document_model.__module__))
1092-
1093-
include_fields = {
1094-
name: (
1095-
model_fields[name].annotation,
1096-
model_fields[name],
1075+
vars(import_module(self.document_model.__module__))
1076+
1077+
include_fields: dict[str, tuple[type, FieldInfo]] = {}
1078+
for name in set_fields:
1079+
field_copy = model_fields[name]._copy()
1080+
field_copy.default = None
1081+
include_fields[name] = (
1082+
Optional[model_fields[name].annotation],
1083+
field_copy,
10971084
)
1098-
for name in set_fields
1099-
}
11001085

11011086
data_model = create_model( # type: ignore
11021087
"MPDataDoc",
11031088
**include_fields,
11041089
# TODO fields_not_requested is not the same as unset_fields
11051090
# i.e. field could be requested but not available in the raw doc
11061091
fields_not_requested=(list[str], unset_fields),
1107-
__base__=self.document_model,
1092+
__doc__=".".join(
1093+
[
1094+
getattr(self.document_model, k, "")
1095+
for k in ("__module__", "__name__")
1096+
]
1097+
),
1098+
__module__=self.document_model.__module__,
11081099
)
1109-
if other_vars:
1110-
data_model.model_rebuild(_types_namespace=other_vars)
1100+
# if other_vars:
1101+
# data_model.model_rebuild(_types_namespace=other_vars)
1102+
1103+
orig_rester_name = self.document_model.__name__
11111104

11121105
def new_repr(self) -> str:
11131106
extra = ",\n".join(
@@ -1116,7 +1109,7 @@ def new_repr(self) -> str:
11161109
if n == "fields_not_requested" or n in set_fields
11171110
)
11181111

1119-
s = f"\033[4m\033[1m{self.__class__.__name__}<{self.__class__.__base__.__name__}>\033[0;0m\033[0;0m(\n{extra}\n)" # noqa: E501
1112+
s = f"\033[4m\033[1m{self.__class__.__name__}<{orig_rester_name}>\033[0;0m\033[0;0m(\n{extra}\n)" # noqa: E501
11201113
return s
11211114

11221115
def new_str(self) -> str:
@@ -1230,8 +1223,14 @@ def get_data_by_id(
12301223
stacklevel=2,
12311224
)
12321225

1233-
if self.primary_key in ["material_id", "task_id"]:
1234-
validate_ids([document_id])
1226+
if self.primary_key in [
1227+
"material_id",
1228+
"task_id",
1229+
"battery_id",
1230+
"spectrum_id",
1231+
"thermo_id",
1232+
]:
1233+
document_id = validate_ids([document_id])[0]
12351234

12361235
if isinstance(fields, str): # pragma: no cover
12371236
fields = (fields,) # type: ignore

0 commit comments

Comments
 (0)