77
88import inspect
99import itertools
10- import json
1110import os
1211import platform
1312import sys
2322 TYPE_CHECKING ,
2423 ForwardRef ,
2524 Generic ,
25+ Optional ,
2626 TypeVar ,
2727 get_args ,
2828)
2929from urllib .parse import quote , urljoin
3030
3131import requests
32- from bson import json_util
3332from emmet .core .utils import jsanitize
34- from monty .json import MontyDecoder
3533from pydantic import BaseModel , create_model
3634from requests .adapters import HTTPAdapter
3735from requests .exceptions import RequestException
4038from urllib3 .util .retry import Retry
4139
4240from mp_api .client .core .settings import MAPIClientSettings
43- from mp_api .client .core .utils import api_sanitize , validate_ids
41+ from mp_api .client .core .utils import load_json , validate_ids
4442
4543try :
4644 import boto3
5755if TYPE_CHECKING :
5856 from typing import Any , Callable
5957
58+ from pydantic .fields import FieldInfo
59+
6060try :
6161 __version__ = version ("mp_api" )
6262except PackageNotFoundError : # pragma: no cover
@@ -150,12 +150,6 @@ def __init__(
150150 else :
151151 self ._s3_client = None
152152
153- self .document_model = (
154- api_sanitize (self .document_model ) # type: ignore
155- if self .document_model is not None
156- else None # type: ignore
157- )
158-
159153 @property
160154 def session (self ) -> requests .Session :
161155 if not self ._session :
@@ -270,11 +264,7 @@ def _post_resource(
270264 response = self .session .post (url , json = payload , verify = True , params = params )
271265
272266 if response .status_code == 200 :
273- if self .monty_decode :
274- data = json .loads (response .text , cls = MontyDecoder )
275- else :
276- data = json .loads (response .text )
277-
267+ data = load_json (response .text , deser = self .monty_decode )
278268 if self .document_model and use_document_model :
279269 if isinstance (data ["data" ], dict ):
280270 data ["data" ] = self .document_model .model_validate (data ["data" ]) # type: ignore
@@ -287,7 +277,7 @@ def _post_resource(
287277
288278 else :
289279 try :
290- data = json . loads (response .text )["detail" ]
280+ data = load_json (response .text )["detail" ]
291281 except (JSONDecodeError , KeyError ):
292282 data = f"Response { response .text } "
293283 if isinstance (data , str ):
@@ -342,11 +332,7 @@ def _patch_resource(
342332 response = self .session .patch (url , json = payload , verify = True , params = params )
343333
344334 if response .status_code == 200 :
345- if self .monty_decode :
346- data = json .loads (response .text , cls = MontyDecoder )
347- else :
348- data = json .loads (response .text )
349-
335+ data = load_json (response .text , deser = self .monty_decode )
350336 if self .document_model and use_document_model :
351337 if isinstance (data ["data" ], dict ):
352338 data ["data" ] = self .document_model .model_validate (data ["data" ]) # type: ignore
@@ -359,7 +345,7 @@ def _patch_resource(
359345
360346 else :
361347 try :
362- data = json . loads (response .text )["detail" ]
348+ data = load_json (response .text )["detail" ]
363349 except (JSONDecodeError , KeyError ):
364350 data = f"Response { response .text } "
365351 if isinstance (data , str ):
@@ -384,18 +370,24 @@ def _query_open_data(
384370 self ,
385371 bucket : str ,
386372 key : str ,
387- decoder : Callable ,
373+ decoder : Callable | None = None ,
388374 ) -> tuple [list [dict ] | list [bytes ], int ]:
389375 """Query and deserialize Materials Project AWS open data s3 buckets.
390376
391377 Args:
392378 bucket (str): Materials project bucket name
393379 key (str): Key for file including all prefixes
394- decoder(Callable): Callable used to deserialize data
380+ decoder(Callable or None): Callable used to deserialize data.
381+ Defaults to mp_api.client.core.utils.load_json
395382
396383 Returns:
397384 dict: MontyDecoded data
398385 """
386+ if not decoder :
387+
388+ def decoder (x ):
389+ return load_json (x , deser = self .monty_decode )
390+
399391 file = open (
400392 f"s3://{ bucket } /{ key } " ,
401393 encoding = "utf-8" ,
@@ -527,16 +519,11 @@ def _query_resource(
527519 "Ignoring `fields` argument: All fields are always included when no query is provided."
528520 )
529521
530- decoder = (
531- MontyDecoder ().decode if self .monty_decode else json_util .loads
532- )
533-
534522 # Multithreaded function inputs
535523 s3_params_list = {
536524 key : {
537525 "bucket" : bucket ,
538526 "key" : key ,
539- "decoder" : decoder ,
540527 }
541528 for key in keys
542529 }
@@ -1013,11 +1000,7 @@ def _submit_request_and_process(
10131000 )
10141001
10151002 if response .status_code == 200 :
1016- if self .monty_decode :
1017- data = json .loads (response .text , cls = MontyDecoder )
1018- else :
1019- data = json .loads (response .text )
1020-
1003+ data = load_json (response .text , deser = self .monty_decode )
10211004 # other sub-urls may use different document models
10221005 # the client does not handle this in a particularly smart way currently
10231006 if self .document_model and use_document_model :
@@ -1029,7 +1012,7 @@ def _submit_request_and_process(
10291012
10301013 else :
10311014 try :
1032- data = json . loads (response .text )["detail" ]
1015+ data = load_json (response .text )["detail" ]
10331016 except (JSONDecodeError , KeyError ):
10341017 data = f"Response { response .text } "
10351018 if isinstance (data , str ):
@@ -1057,10 +1040,8 @@ def _convert_to_model(self, data: list[dict]):
10571040 (list[MPDataDoc]): List of MPDataDoc objects
10581041
10591042 """
1060- raw_doc_list = [self .document_model .model_validate (d ) for d in data ] # type: ignore
1061-
1062- if len (raw_doc_list ) > 0 :
1063- data_model , set_fields , _ = self ._generate_returned_model (raw_doc_list [0 ])
1043+ if len (data ) > 0 :
1044+ data_model , set_fields , _ = self ._generate_returned_model (data [0 ])
10641045
10651046 data = [
10661047 data_model (
@@ -1070,44 +1051,56 @@ def _convert_to_model(self, data: list[dict]):
10701051 if field in set_fields
10711052 }
10721053 )
1073- for raw_doc in raw_doc_list
1054+ for raw_doc in data
10741055 ]
10751056
10761057 return data
10771058
1078- def _generate_returned_model (self , doc ):
1059+ def _generate_returned_model (
1060+ self , doc : dict [str , Any ]
1061+ ) -> tuple [BaseModel , list [str ], list [str ]]:
10791062 model_fields = self .document_model .model_fields
1080-
1081- set_fields = doc .model_fields_set
1063+ set_fields = [k for k in doc if k in model_fields ]
10821064 unset_fields = [field for field in model_fields if field not in set_fields ]
10831065
10841066 # Update with locals() from external module if needed
1085- other_vars = {}
10861067 if any (
1068+ isinstance (field_meta .annotation , ForwardRef )
1069+ for field_meta in model_fields .values ()
1070+ ) or any (
10871071 isinstance (typ , ForwardRef )
10881072 for field_meta in model_fields .values ()
10891073 for typ in get_args (field_meta .annotation )
10901074 ):
1091- other_vars = vars (import_module (self .document_model .__module__ ))
1092-
1093- include_fields = {
1094- name : (
1095- model_fields [name ].annotation ,
1096- model_fields [name ],
1075+ vars (import_module (self .document_model .__module__ ))
1076+
1077+ include_fields : dict [str , tuple [type , FieldInfo ]] = {}
1078+ for name in set_fields :
1079+ field_copy = model_fields [name ]._copy ()
1080+ field_copy .default = None
1081+ include_fields [name ] = (
1082+ Optional [model_fields [name ].annotation ],
1083+ field_copy ,
10971084 )
1098- for name in set_fields
1099- }
11001085
11011086 data_model = create_model ( # type: ignore
11021087 "MPDataDoc" ,
11031088 ** include_fields ,
11041089 # TODO fields_not_requested is not the same as unset_fields
11051090 # i.e. field could be requested but not available in the raw doc
11061091 fields_not_requested = (list [str ], unset_fields ),
1107- __base__ = self .document_model ,
1092+ __doc__ = "." .join (
1093+ [
1094+ getattr (self .document_model , k , "" )
1095+ for k in ("__module__" , "__name__" )
1096+ ]
1097+ ),
1098+ __module__ = self .document_model .__module__ ,
11081099 )
1109- if other_vars :
1110- data_model .model_rebuild (_types_namespace = other_vars )
1100+ # if other_vars:
1101+ # data_model.model_rebuild(_types_namespace=other_vars)
1102+
1103+ orig_rester_name = self .document_model .__name__
11111104
11121105 def new_repr (self ) -> str :
11131106 extra = ",\n " .join (
@@ -1116,7 +1109,7 @@ def new_repr(self) -> str:
11161109 if n == "fields_not_requested" or n in set_fields
11171110 )
11181111
1119- s = f"\033 [4m\033 [1m{ self .__class__ .__name__ } <{ self . __class__ . __base__ . __name__ } >\033 [0;0m\033 [0;0m(\n { extra } \n )" # noqa: E501
1112+ s = f"\033 [4m\033 [1m{ self .__class__ .__name__ } <{ orig_rester_name } >\033 [0;0m\033 [0;0m(\n { extra } \n )" # noqa: E501
11201113 return s
11211114
11221115 def new_str (self ) -> str :
@@ -1230,8 +1223,14 @@ def get_data_by_id(
12301223 stacklevel = 2 ,
12311224 )
12321225
1233- if self .primary_key in ["material_id" , "task_id" ]:
1234- validate_ids ([document_id ])
1226+ if self .primary_key in [
1227+ "material_id" ,
1228+ "task_id" ,
1229+ "battery_id" ,
1230+ "spectrum_id" ,
1231+ "thermo_id" ,
1232+ ]:
1233+ document_id = validate_ids ([document_id ])[0 ]
12351234
12361235 if isinstance (fields , str ): # pragma: no cover
12371236 fields = (fields ,) # type: ignore
0 commit comments