55
66from __future__ import annotations
77
8+ import gzip
89import inspect
910import itertools
1011import os
1617from functools import cache
1718from importlib import import_module
1819from importlib .metadata import PackageNotFoundError , version
20+ from io import BytesIO
1921from json import JSONDecodeError
2022from math import ceil
2123from typing import TYPE_CHECKING , ForwardRef , Optional , get_args
22- from urllib .parse import quote , urljoin
24+ from urllib .parse import quote
2325
26+ import boto3
2427import requests
28+ from botocore import UNSIGNED
29+ from botocore .config import Config
30+ from botocore .exceptions import ClientError
2531from emmet .core .utils import jsanitize
2632from pydantic import BaseModel , create_model
2733from requests .adapters import HTTPAdapter
2834from requests .exceptions import RequestException
29- from smart_open import open
3035from tqdm .auto import tqdm
3136from urllib3 .util .retry import Retry
3237
3338from mp_api .client .core .exceptions import MPRestError
34- from mp_api .client .core .settings import MAPIClientSettings
35- from mp_api .client .core .utils import load_json , validate_ids
36-
37- try :
38- import boto3
39- from botocore import UNSIGNED
40- from botocore .config import Config
41- except ImportError :
42- boto3 = None
39+ from mp_api .client .core .settings import MAPI_CLIENT_SETTINGS
40+ from mp_api .client .core .utils import (
41+ load_json ,
42+ validate_api_key ,
43+ validate_endpoint ,
44+ validate_ids ,
45+ )
4346
4447try :
4548 import flask
5154
5255 from pydantic .fields import FieldInfo
5356
57+ from mp_api .client .core .utils import LazyImport
58+
5459try :
5560 __version__ = version ("mp_api" )
5661except PackageNotFoundError : # pragma: no cover
5762 __version__ = os .getenv ("SETUPTOOLS_SCM_PRETEND_VERSION" )
5863
5964
60- SETTINGS = MAPIClientSettings () # type: ignore
61-
62-
6365class _DictLikeAccess (BaseModel ):
6466 """Define a pydantic mix-in which permits dict-like access to model fields."""
6567
@@ -82,7 +84,6 @@ class BaseRester:
8284
8385 suffix : str = ""
8486 document_model : type [BaseModel ] | None = None
85- supports_versions : bool = False
8687 primary_key : str = "material_id"
8788
8889 def __init__ (
@@ -96,7 +97,7 @@ def __init__(
9697 use_document_model : bool = True ,
9798 timeout : int = 20 ,
9899 headers : dict | None = None ,
99- mute_progress_bars : bool = SETTINGS .MUTE_PROGRESS_BARS ,
100+ mute_progress_bars : bool = MAPI_CLIENT_SETTINGS .MUTE_PROGRESS_BARS ,
100101 ** kwargs ,
101102 ):
102103 """Initialize the REST API helper class.
@@ -130,23 +131,17 @@ def __init__(
130131 mute_progress_bars: Whether to disable progress bars.
131132 **kwargs: access to legacy kwargs that may be in the process of being deprecated
132133 """
133- # TODO: think about how to migrate from PMG_MAPI_KEY
134- self .api_key = api_key or os .getenv ("MP_API_KEY" )
135- self .base_endpoint = self .endpoint = endpoint or os .getenv (
136- "MP_API_ENDPOINT" , "https://api.materialsproject.org/"
137- )
134+ self .api_key = validate_api_key (api_key )
135+ self .base_endpoint = validate_endpoint (endpoint )
136+ self .endpoint = validate_endpoint (endpoint , suffix = self .suffix )
137+
138138 self .debug = debug
139139 self .include_user_agent = include_user_agent
140140 self .use_document_model = use_document_model
141141 self .timeout = timeout
142142 self .headers = headers or {}
143143 self .mute_progress_bars = mute_progress_bars
144- self .db_version = BaseRester ._get_database_version (self .endpoint )
145-
146- if self .suffix :
147- self .endpoint = urljoin (self .endpoint , self .suffix )
148- if not self .endpoint .endswith ("/" ):
149- self .endpoint += "/"
144+ self .db_version = BaseRester ._get_database_version (self .base_endpoint )
150145
151146 self ._session = session
152147 self ._s3_client = s3_client
@@ -167,13 +162,6 @@ def session(self) -> requests.Session:
167162
168163 @property
169164 def s3_client (self ):
170- if boto3 is None :
171- raise MPRestError (
172- "boto3 not installed. To query charge density, "
173- "band structure, or density of states data first "
174- "install with: 'pip install boto3'"
175- )
176-
177165 if not self ._s3_client :
178166 self ._s3_client = boto3 .client (
179167 "s3" ,
@@ -194,15 +182,14 @@ def _create_session(api_key, include_user_agent, headers):
194182 user_agent = f"{ mp_api_info } ({ python_info } { platform_info } )"
195183 session .headers ["user-agent" ] = user_agent
196184
197- settings = MAPIClientSettings () # type: ignore
198- max_retry_num = settings .MAX_RETRIES
185+ max_retry_num = MAPI_CLIENT_SETTINGS .MAX_RETRIES
199186 retry = Retry (
200187 total = max_retry_num ,
201188 read = max_retry_num ,
202189 connect = max_retry_num ,
203190 respect_retry_after_header = True ,
204191 status_forcelist = [429 , 504 , 502 ], # rate limiting
205- backoff_factor = settings .BACKOFF_FACTOR ,
192+ backoff_factor = MAPI_CLIENT_SETTINGS .BACKOFF_FACTOR ,
206193 )
207194 adapter = HTTPAdapter (max_retries = retry )
208195 session .mount ("http://" , adapter )
@@ -263,11 +250,7 @@ def _post_resource(
263250 payload = jsanitize (body )
264251
265252 try :
266- url = self .endpoint
267- if suburl :
268- url = urljoin (self .endpoint , suburl )
269- if not url .endswith ("/" ):
270- url += "/"
253+ url = validate_endpoint (self .endpoint , suffix = suburl )
271254 response = self .session .post (url , json = payload , verify = True , params = params )
272255
273256 if response .status_code == 200 :
@@ -331,11 +314,7 @@ def _patch_resource(
331314 payload = jsanitize (body )
332315
333316 try :
334- url = self .endpoint
335- if suburl :
336- url = urljoin (self .endpoint , suburl )
337- if not url .endswith ("/" ):
338- url += "/"
317+ url = validate_endpoint (self .endpoint , suffix = suburl )
339318 response = self .session .patch (url , json = payload , verify = True , params = params )
340319
341320 if response .status_code == 200 :
@@ -390,20 +369,31 @@ def _query_open_data(
390369 Returns:
391370 dict: MontyDecoded data
392371 """
393- decoder = decoder or load_json
372+ try :
373+ byio = BytesIO ()
374+ self .s3_client .download_fileobj (bucket , key , byio )
375+ byio .seek (0 )
376+ if (file_data := byio .read ()).startswith (b"\x1f \x8b " ):
377+ file_data = gzip .decompress (file_data )
378+ byio .close ()
394379
395- file = open (
396- f"s3://{ bucket } /{ key } " ,
397- encoding = "utf-8" ,
398- transport_params = {"client" : self .s3_client },
399- )
380+ decoder = decoder or load_json
400381
401- if "jsonl" in key :
402- decoded_data = [decoder (jline ) for jline in file .read ().splitlines ()]
403- else :
404- decoded_data = decoder (file .read ())
405- if not isinstance (decoded_data , list ):
406- decoded_data = [decoded_data ]
382+ if "jsonl" in key :
383+ decoded_data = [decoder (jline ) for jline in file_data .splitlines ()]
384+ else :
385+ decoded_data = decoder (file_data )
386+ if not isinstance (decoded_data , list ):
387+ decoded_data = [decoded_data ]
388+
389+ raise_error = not decoded_data or len (decoded_data ) == 0
390+
391+ except ClientError :
392+ # No such object exists
393+ raise_error = True
394+
395+ if raise_error :
396+ raise MPRestError (f"No object found: s3://{ bucket } /{ key } " )
407397
408398 return decoded_data , len (decoded_data ) # type: ignore
409399
@@ -467,14 +457,9 @@ def _query_resource(
467457 criteria ["_fields" ] = "," .join (fields )
468458
469459 try :
470- url = self .endpoint
471- if suburl :
472- url = urljoin (self .endpoint , suburl )
473- if not url .endswith ("/" ):
474- url += "/"
460+ url = validate_endpoint (self .endpoint , suffix = suburl )
475461
476462 if query_s3 :
477- db_version = self .db_version .replace ("." , "-" )
478463 if "/" not in self .suffix :
479464 suffix = self .suffix
480465 elif self .suffix == "molecules/summary" :
@@ -490,7 +475,7 @@ def _query_resource(
490475 bucket_suffix , prefix = "parsed" , "tasks_atomate2"
491476 else :
492477 bucket_suffix = "build"
493- prefix = f"collections/{ db_version } /{ suffix } "
478+ prefix = f"collections/{ self . db_version . replace ( '.' , '-' ) } /{ suffix } "
494479
495480 bucket = f"materialsproject-{ bucket_suffix } "
496481 paginator = self .s3_client .get_paginator ("list_objects_v2" )
@@ -618,15 +603,15 @@ def _submit_requests( # noqa
618603
619604 bare_url_len = len (url_string )
620605 max_param_str_length = (
621- MAPIClientSettings () .MAX_HTTP_URL_LENGTH - bare_url_len # type: ignore
606+ MAPI_CLIENT_SETTINGS .MAX_HTTP_URL_LENGTH - bare_url_len # type: ignore
622607 )
623608
624609 # Next, check if default number of parallel requests works.
625610 # If not, make slice size the minimum number of param entries
626611 # contained in any substring of length max_param_str_length.
627612 param_length = len (criteria [parallel_param ].split ("," ))
628613 slice_size = (
629- int (param_length / MAPIClientSettings () .NUM_PARALLEL_REQUESTS ) or 1 # type: ignore
614+ int (param_length / MAPI_CLIENT_SETTINGS .NUM_PARALLEL_REQUESTS ) or 1 # type: ignore
630615 )
631616
632617 url_param_string = quote (criteria [parallel_param ])
@@ -687,7 +672,7 @@ def _submit_requests( # noqa
687672 new_limits = [chunk_size ]
688673
689674 total_num_docs = 0
690- total_data = {"data" : []} # type: dict
675+ total_data : dict [ str , list [ Any ]] = {"data" : []}
691676
692677 # Obtain first page of results and get pagination information.
693678 # Individual total document limits (subtotal) will potentially
@@ -907,14 +892,14 @@ def _multi_thread(
907892 params_ind = 0
908893
909894 with ThreadPoolExecutor (
910- max_workers = MAPIClientSettings () .NUM_PARALLEL_REQUESTS # type: ignore
895+ max_workers = MAPI_CLIENT_SETTINGS .NUM_PARALLEL_REQUESTS # type: ignore
911896 ) as executor :
912897 # Get list of initial futures defined by max number of parallel requests
913898 futures = set ()
914899
915900 for params in itertools .islice (
916901 params_gen ,
917- MAPIClientSettings () .NUM_PARALLEL_REQUESTS , # type: ignore
902+ MAPI_CLIENT_SETTINGS .NUM_PARALLEL_REQUESTS , # type: ignore
918903 ):
919904 future = executor .submit (
920905 func ,
@@ -1276,7 +1261,7 @@ def _get_all_documents(
12761261 for key , entry in query_params .items ()
12771262 if isinstance (entry , str )
12781263 and len (entry .split ("," )) > 0
1279- and key not in MAPIClientSettings () .QUERY_NO_PARALLEL # type: ignore
1264+ and key not in MAPI_CLIENT_SETTINGS .QUERY_NO_PARALLEL # type: ignore
12801265 ),
12811266 key = lambda item : item [1 ],
12821267 reverse = True ,
@@ -1351,3 +1336,37 @@ def __str__(self): # pragma: no cover
13511336 f"{ self .__class__ .__name__ } connected to { self .endpoint } \n \n "
13521337 f"Available fields: { ', ' .join (self .available_fields )} \n \n "
13531338 )
1339+
1340+
class CoreRester(BaseRester):
    """Define a BaseRester with extra features for core resters.

    Enables lazy importing / initialization of sub resters
    provided in `_sub_resters`, which should be a map
    of endpoint names to LazyImport objects.

    Each sub rester is exposed as an attribute of the same name. The first
    access calls the corresponding LazyImport with this rester's connection
    settings; later accesses return the same LazyImport without calling it
    again.
    """

    _sub_resters: dict[str, LazyImport] = {}

    def __init__(self, **kwargs):
        """Initialize the rester and take per-instance copies of the sub resters.

        Copying the class-level `_sub_resters` entries ensures that sub
        resters are unset on re-init.

        Args:
            **kwargs: forwarded unchanged to ``BaseRester.__init__``.
        """
        super().__init__(**kwargs)
        self.sub_resters = {k: v.copy() for k, v in self._sub_resters.items()}

    def __getattr__(self, v: str):
        """Resolve `v` as a lazily initialized sub rester.

        Python only calls this hook after normal attribute lookup has failed.

        Args:
            v: name of the attribute being looked up.

        Returns:
            The LazyImport registered under `v`, called once with this
            rester's settings if its `_obj` was still None.

        Raises:
            AttributeError: if `v` is not a registered sub rester.
        """
        # Look up `sub_resters` without re-entering this hook: it is not set
        # until `__init__` has finished (and not at all on instances created
        # via copy/pickle), and `self.sub_resters` would then recurse forever.
        try:
            sub_resters = object.__getattribute__(self, "sub_resters")
        except AttributeError:
            sub_resters = {}

        if v not in sub_resters:
            # Follow the attribute protocol so that `hasattr` and
            # `getattr(obj, name, default)` behave correctly, instead of
            # silently returning None for every unknown name.
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {v!r}"
            )

        sub_rester = sub_resters[v]
        if sub_rester._obj is None:
            # First access: hand over this rester's connection settings.
            # NOTE(review): presumably this instantiates the wrapped rester
            # class - confirm against LazyImport.__call__.
            sub_rester(
                api_key=self.api_key,
                endpoint=self.base_endpoint,
                include_user_agent=self.include_user_agent,
                session=self.session,
                use_document_model=self.use_document_model,
                headers=self.headers,
                mute_progress_bars=self.mute_progress_bars,
            )
        return sub_rester

    def __dir__(self):
        """List class-level attributes plus the names of the sub resters."""
        return dir(self.__class__) + list(self._sub_resters)
0 commit comments