Skip to content

Commit cf9bde6

Browse files
Lazy loading (#1044)
2 parents e1cd425 + d3a5bc2 commit cf9bde6

78 files changed

Lines changed: 771 additions & 3700 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

.github/workflows/testing.yml

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@ concurrency:
1515
jobs:
1616
test:
1717
strategy:
18-
max-parallel: 2
1918
matrix:
2019
os: ["ubuntu-latest"]
2120
python-version: ["3.11", "3.12"]
@@ -64,13 +63,3 @@ jobs:
6463
with:
6564
token: ${{ secrets.CODECOV_TOKEN }}
6665
file: ./coverage.xml
67-
68-
auto-gen-release:
69-
needs: [test]
70-
runs-on: ubuntu-latest
71-
env:
72-
GITHUB_TOKEN: ${{ secrets.API_VER_BUMP_TOKEN }}
73-
steps:
74-
- uses: rymndhng/release-on-push-action@v0.20.0
75-
with:
76-
bump_version_scheme: norelease

mp_api/client/core/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,4 @@
22

33
from .client import BaseRester
44
from .exceptions import MPRestError, MPRestWarning
5-
from .settings import MAPIClientSettings
5+
from .settings import MAPI_CLIENT_SETTINGS, MAPIClientSettings

mp_api/client/core/client.py

Lines changed: 91 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
from __future__ import annotations
77

8+
import gzip
89
import inspect
910
import itertools
1011
import os
@@ -16,30 +17,32 @@
1617
from functools import cache
1718
from importlib import import_module
1819
from importlib.metadata import PackageNotFoundError, version
20+
from io import BytesIO
1921
from json import JSONDecodeError
2022
from math import ceil
2123
from typing import TYPE_CHECKING, ForwardRef, Optional, get_args
22-
from urllib.parse import quote, urljoin
24+
from urllib.parse import quote
2325

26+
import boto3
2427
import requests
28+
from botocore import UNSIGNED
29+
from botocore.config import Config
30+
from botocore.exceptions import ClientError
2531
from emmet.core.utils import jsanitize
2632
from pydantic import BaseModel, create_model
2733
from requests.adapters import HTTPAdapter
2834
from requests.exceptions import RequestException
29-
from smart_open import open
3035
from tqdm.auto import tqdm
3136
from urllib3.util.retry import Retry
3237

3338
from mp_api.client.core.exceptions import MPRestError
34-
from mp_api.client.core.settings import MAPIClientSettings
35-
from mp_api.client.core.utils import load_json, validate_ids
36-
37-
try:
38-
import boto3
39-
from botocore import UNSIGNED
40-
from botocore.config import Config
41-
except ImportError:
42-
boto3 = None
39+
from mp_api.client.core.settings import MAPI_CLIENT_SETTINGS
40+
from mp_api.client.core.utils import (
41+
load_json,
42+
validate_api_key,
43+
validate_endpoint,
44+
validate_ids,
45+
)
4346

4447
try:
4548
import flask
@@ -51,15 +54,14 @@
5154

5255
from pydantic.fields import FieldInfo
5356

57+
from mp_api.client.core.utils import LazyImport
58+
5459
try:
5560
__version__ = version("mp_api")
5661
except PackageNotFoundError: # pragma: no cover
5762
__version__ = os.getenv("SETUPTOOLS_SCM_PRETEND_VERSION")
5863

5964

60-
SETTINGS = MAPIClientSettings() # type: ignore
61-
62-
6365
class _DictLikeAccess(BaseModel):
6466
"""Define a pydantic mix-in which permits dict-like access to model fields."""
6567

@@ -82,7 +84,6 @@ class BaseRester:
8284

8385
suffix: str = ""
8486
document_model: type[BaseModel] | None = None
85-
supports_versions: bool = False
8687
primary_key: str = "material_id"
8788

8889
def __init__(
@@ -96,7 +97,7 @@ def __init__(
9697
use_document_model: bool = True,
9798
timeout: int = 20,
9899
headers: dict | None = None,
99-
mute_progress_bars: bool = SETTINGS.MUTE_PROGRESS_BARS,
100+
mute_progress_bars: bool = MAPI_CLIENT_SETTINGS.MUTE_PROGRESS_BARS,
100101
**kwargs,
101102
):
102103
"""Initialize the REST API helper class.
@@ -130,23 +131,17 @@ def __init__(
130131
mute_progress_bars: Whether to disable progress bars.
131132
**kwargs: access to legacy kwargs that may be in the process of being deprecated
132133
"""
133-
# TODO: think about how to migrate from PMG_MAPI_KEY
134-
self.api_key = api_key or os.getenv("MP_API_KEY")
135-
self.base_endpoint = self.endpoint = endpoint or os.getenv(
136-
"MP_API_ENDPOINT", "https://api.materialsproject.org/"
137-
)
134+
self.api_key = validate_api_key(api_key)
135+
self.base_endpoint = validate_endpoint(endpoint)
136+
self.endpoint = validate_endpoint(endpoint, suffix=self.suffix)
137+
138138
self.debug = debug
139139
self.include_user_agent = include_user_agent
140140
self.use_document_model = use_document_model
141141
self.timeout = timeout
142142
self.headers = headers or {}
143143
self.mute_progress_bars = mute_progress_bars
144-
self.db_version = BaseRester._get_database_version(self.endpoint)
145-
146-
if self.suffix:
147-
self.endpoint = urljoin(self.endpoint, self.suffix)
148-
if not self.endpoint.endswith("/"):
149-
self.endpoint += "/"
144+
self.db_version = BaseRester._get_database_version(self.base_endpoint)
150145

151146
self._session = session
152147
self._s3_client = s3_client
@@ -167,13 +162,6 @@ def session(self) -> requests.Session:
167162

168163
@property
169164
def s3_client(self):
170-
if boto3 is None:
171-
raise MPRestError(
172-
"boto3 not installed. To query charge density, "
173-
"band structure, or density of states data first "
174-
"install with: 'pip install boto3'"
175-
)
176-
177165
if not self._s3_client:
178166
self._s3_client = boto3.client(
179167
"s3",
@@ -194,15 +182,14 @@ def _create_session(api_key, include_user_agent, headers):
194182
user_agent = f"{mp_api_info} ({python_info} {platform_info})"
195183
session.headers["user-agent"] = user_agent
196184

197-
settings = MAPIClientSettings() # type: ignore
198-
max_retry_num = settings.MAX_RETRIES
185+
max_retry_num = MAPI_CLIENT_SETTINGS.MAX_RETRIES
199186
retry = Retry(
200187
total=max_retry_num,
201188
read=max_retry_num,
202189
connect=max_retry_num,
203190
respect_retry_after_header=True,
204191
status_forcelist=[429, 504, 502], # rate limiting
205-
backoff_factor=settings.BACKOFF_FACTOR,
192+
backoff_factor=MAPI_CLIENT_SETTINGS.BACKOFF_FACTOR,
206193
)
207194
adapter = HTTPAdapter(max_retries=retry)
208195
session.mount("http://", adapter)
@@ -263,11 +250,7 @@ def _post_resource(
263250
payload = jsanitize(body)
264251

265252
try:
266-
url = self.endpoint
267-
if suburl:
268-
url = urljoin(self.endpoint, suburl)
269-
if not url.endswith("/"):
270-
url += "/"
253+
url = validate_endpoint(self.endpoint, suffix=suburl)
271254
response = self.session.post(url, json=payload, verify=True, params=params)
272255

273256
if response.status_code == 200:
@@ -331,11 +314,7 @@ def _patch_resource(
331314
payload = jsanitize(body)
332315

333316
try:
334-
url = self.endpoint
335-
if suburl:
336-
url = urljoin(self.endpoint, suburl)
337-
if not url.endswith("/"):
338-
url += "/"
317+
url = validate_endpoint(self.endpoint, suffix=suburl)
339318
response = self.session.patch(url, json=payload, verify=True, params=params)
340319

341320
if response.status_code == 200:
@@ -390,20 +369,31 @@ def _query_open_data(
390369
Returns:
391370
dict: MontyDecoded data
392371
"""
393-
decoder = decoder or load_json
372+
try:
373+
byio = BytesIO()
374+
self.s3_client.download_fileobj(bucket, key, byio)
375+
byio.seek(0)
376+
if (file_data := byio.read()).startswith(b"\x1f\x8b"):
377+
file_data = gzip.decompress(file_data)
378+
byio.close()
394379

395-
file = open(
396-
f"s3://{bucket}/{key}",
397-
encoding="utf-8",
398-
transport_params={"client": self.s3_client},
399-
)
380+
decoder = decoder or load_json
400381

401-
if "jsonl" in key:
402-
decoded_data = [decoder(jline) for jline in file.read().splitlines()]
403-
else:
404-
decoded_data = decoder(file.read())
405-
if not isinstance(decoded_data, list):
406-
decoded_data = [decoded_data]
382+
if "jsonl" in key:
383+
decoded_data = [decoder(jline) for jline in file_data.splitlines()]
384+
else:
385+
decoded_data = decoder(file_data)
386+
if not isinstance(decoded_data, list):
387+
decoded_data = [decoded_data]
388+
389+
raise_error = not decoded_data or len(decoded_data) == 0
390+
391+
except ClientError:
392+
# No such object exists
393+
raise_error = True
394+
395+
if raise_error:
396+
raise MPRestError(f"No object found: s3://{bucket}/{key}")
407397

408398
return decoded_data, len(decoded_data) # type: ignore
409399

@@ -467,14 +457,9 @@ def _query_resource(
467457
criteria["_fields"] = ",".join(fields)
468458

469459
try:
470-
url = self.endpoint
471-
if suburl:
472-
url = urljoin(self.endpoint, suburl)
473-
if not url.endswith("/"):
474-
url += "/"
460+
url = validate_endpoint(self.endpoint, suffix=suburl)
475461

476462
if query_s3:
477-
db_version = self.db_version.replace(".", "-")
478463
if "/" not in self.suffix:
479464
suffix = self.suffix
480465
elif self.suffix == "molecules/summary":
@@ -490,7 +475,7 @@ def _query_resource(
490475
bucket_suffix, prefix = "parsed", "tasks_atomate2"
491476
else:
492477
bucket_suffix = "build"
493-
prefix = f"collections/{db_version}/{suffix}"
478+
prefix = f"collections/{self.db_version.replace('.', '-')}/{suffix}"
494479

495480
bucket = f"materialsproject-{bucket_suffix}"
496481
paginator = self.s3_client.get_paginator("list_objects_v2")
@@ -618,15 +603,15 @@ def _submit_requests( # noqa
618603

619604
bare_url_len = len(url_string)
620605
max_param_str_length = (
621-
MAPIClientSettings().MAX_HTTP_URL_LENGTH - bare_url_len # type: ignore
606+
MAPI_CLIENT_SETTINGS.MAX_HTTP_URL_LENGTH - bare_url_len # type: ignore
622607
)
623608

624609
# Next, check if default number of parallel requests works.
625610
# If not, make slice size the minimum number of param entries
626611
# contained in any substring of length max_param_str_length.
627612
param_length = len(criteria[parallel_param].split(","))
628613
slice_size = (
629-
int(param_length / MAPIClientSettings().NUM_PARALLEL_REQUESTS) or 1 # type: ignore
614+
int(param_length / MAPI_CLIENT_SETTINGS.NUM_PARALLEL_REQUESTS) or 1 # type: ignore
630615
)
631616

632617
url_param_string = quote(criteria[parallel_param])
@@ -687,7 +672,7 @@ def _submit_requests( # noqa
687672
new_limits = [chunk_size]
688673

689674
total_num_docs = 0
690-
total_data = {"data": []} # type: dict
675+
total_data: dict[str, list[Any]] = {"data": []}
691676

692677
# Obtain first page of results and get pagination information.
693678
# Individual total document limits (subtotal) will potentially
@@ -907,14 +892,14 @@ def _multi_thread(
907892
params_ind = 0
908893

909894
with ThreadPoolExecutor(
910-
max_workers=MAPIClientSettings().NUM_PARALLEL_REQUESTS # type: ignore
895+
max_workers=MAPI_CLIENT_SETTINGS.NUM_PARALLEL_REQUESTS # type: ignore
911896
) as executor:
912897
# Get list of initial futures defined by max number of parallel requests
913898
futures = set()
914899

915900
for params in itertools.islice(
916901
params_gen,
917-
MAPIClientSettings().NUM_PARALLEL_REQUESTS, # type: ignore
902+
MAPI_CLIENT_SETTINGS.NUM_PARALLEL_REQUESTS, # type: ignore
918903
):
919904
future = executor.submit(
920905
func,
@@ -1276,7 +1261,7 @@ def _get_all_documents(
12761261
for key, entry in query_params.items()
12771262
if isinstance(entry, str)
12781263
and len(entry.split(",")) > 0
1279-
and key not in MAPIClientSettings().QUERY_NO_PARALLEL # type: ignore
1264+
and key not in MAPI_CLIENT_SETTINGS.QUERY_NO_PARALLEL # type: ignore
12801265
),
12811266
key=lambda item: item[1],
12821267
reverse=True,
@@ -1351,3 +1336,37 @@ def __str__(self): # pragma: no cover
13511336
f"{self.__class__.__name__} connected to {self.endpoint}\n\n"
13521337
f"Available fields: {', '.join(self.available_fields)}\n\n"
13531338
)
1339+
1340+
1341+
class CoreRester(BaseRester):
    """Define a BaseRester with extra features for core resters.

    Enables lazy importing / initialization of sub resters
    provided in `_sub_resters`, which should be a map
    of endpoint names to LazyImport objects.

    A sub rester is initialized the first time it is accessed as an
    attribute of this rester, using this rester's API key, base endpoint,
    session, headers, and display options.
    """

    # Subclasses override this with their own endpoint-name -> LazyImport map.
    _sub_resters: dict[str, LazyImport] = {}

    def __init__(self, **kwargs):
        """Ensure that sub resters are unset on re-init.

        Args:
            **kwargs: Forwarded unchanged to ``BaseRester.__init__``.
        """
        super().__init__(**kwargs)
        # Give every instance its own copies so that initializing a sub rester
        # on one instance does not affect the class-level map.
        # NOTE(review): assumes LazyImport.copy() returns an uninitialized
        # copy - confirm against LazyImport.
        self.sub_resters = {k: v.copy() for k, v in self._sub_resters.items()}

    def __getattr__(self, v: str):
        """Resolve sub resters by attribute name, initializing them lazily.

        Python only calls this hook after normal attribute lookup has failed.

        Args:
            v: Name of the attribute being looked up.

        Returns:
            The LazyImport registered under ``v``, initialized on first access.

        Raises:
            AttributeError: If ``v`` is not the name of a known sub rester.
        """
        # Read the map straight from the instance dict: going through
        # `self.sub_resters` would re-enter __getattr__ and recurse forever
        # whenever the attribute has not been set yet (e.g. while
        # BaseRester.__init__ is still running, or during copy/unpickle).
        sub_resters = self.__dict__.get("sub_resters", {})
        if v not in sub_resters:
            # Falling through would implicitly return None, which makes
            # hasattr() always succeed and hides misspelled attribute names.
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {v!r}"
            )

        sub_rester = sub_resters[v]
        # A LazyImport whose `_obj` is still None is treated as not yet
        # initialized; calling it passes this rester's configuration through.
        if sub_rester._obj is None:
            sub_rester(
                api_key=self.api_key,
                endpoint=self.base_endpoint,
                # BaseRester.__init__ stores this flag as `include_user_agent`
                # (no leading underscore).
                include_user_agent=self.include_user_agent,
                session=self.session,
                use_document_model=self.use_document_model,
                headers=self.headers,
                mute_progress_bars=self.mute_progress_bars,
            )
        return sub_rester

    def __dir__(self):
        """List class attributes plus the names of all registered sub resters."""
        return dir(self.__class__) + list(self._sub_resters)

0 commit comments

Comments
 (0)