Skip to content

Commit 8a222a7

Browse files
add similarity tests for regular and vector search
1 parent 32714d3 commit 8a222a7

2 files changed

Lines changed: 88 additions & 1 deletion

File tree

mp_api/client/routes/materials/similarity.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,10 @@ def find_similar(
117117
feature_vector = self.fingerprint_structure(structure_or_mpid)
118118

119119
else:
120-
raise ValueError("Please submit a pymatgen Structure or MP ID.")
120+
raise MPRestError(
121+
"Please submit a pymatgen Structure or MP ID, found ."
122+
f"structure_or_mpid = {type(structure_or_mpid)}."
123+
)
121124

122125
top = top or MAX_VECTOR_SEARCH_RESULTS
123126
if not isinstance(top, int) or top < 1:

tests/materials/test_similarity.py

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
2+
import os
3+
4+
import numpy as np
5+
import pytest
6+
7+
from emmet.core.similarity import SimilarityScorer, SimilarityEntry
8+
from pymatgen.core import Structure
9+
10+
from mp_api.client.core import MPRestError
11+
from mp_api.client.routes.materials.similarity import SimilarityRester
12+
13+
from core_function import client_search_testing
14+
15+
@pytest.fixture(scope="module")
16+
def test_struct():
17+
poscar = """Al2
18+
1.0
19+
3.3335729972004917 0.0000000000000000 1.9246389981090721
20+
1.1111909992801432 3.1429239987499362 1.9246389992542632
21+
0.0000000000000000 0.0000000000000000 3.8492780000000000
22+
Al
23+
2
24+
direct
25+
0.875 0.875 0.875 Al
26+
0.125 0.125 0.125 Al
27+
"""
28+
return Structure.from_str(poscar,fmt="poscar")
29+
30+
@pytest.mark.skipif(os.getenv("MP_API_KEY") is None, reason="No API key found.")
31+
def test_client():
32+
client_search_testing(
33+
search_method=SimilarityRester().search,
34+
excluded_params=[
35+
"num_chunks",
36+
"chunk_size",
37+
"all_fields",
38+
"fields",
39+
],
40+
alt_name_dict={
41+
"material_ids": "material_id",
42+
},
43+
custom_field_tests = {
44+
"material_ids": ["mp-149","mp-13"],
45+
"material_ids": "mp-149"
46+
},
47+
sub_doc_fields=[],
48+
)
49+
50+
@pytest.mark.skipif(os.getenv("MP_API_KEY") is None, reason="No API key found.")
51+
def test_similarity_vector_search(test_struct):
52+
53+
rester = SimilarityRester()
54+
fv = rester.fingerprint_structure(test_struct)
55+
assert isinstance(fv,np.ndarray)
56+
assert len(fv) == 122
57+
assert isinstance(rester._fingerprinter,SimilarityScorer)
58+
59+
60+
get_top = 5
61+
sim_entries = rester.find_similar("mp-149",top=get_top)
62+
assert all(
63+
isinstance(entry,SimilarityEntry) for entry in sim_entries
64+
)
65+
assert len(sim_entries) == get_top
66+
67+
sim_dict_entries = SimilarityRester(use_document_model=False).find_similar("mp-149",top=get_top)
68+
assert all(
69+
isinstance(entry,dict) and all(
70+
k in entry for k in SimilarityEntry.model_fields
71+
)
72+
for entry in sim_dict_entries
73+
)
74+
75+
with pytest.raises(MPRestError,match="No similarity data available for"):
76+
_ = rester.find_similar("mp-0")
77+
78+
assert all(
79+
isinstance(entry,SimilarityEntry) and isinstance(entry.dissimilarity,float)
80+
for entry in rester.find_similar(test_struct, top = 2,)
81+
)
82+
83+
with pytest.raises(MPRestError,match="Please submit a pymatgen Structure or MP ID"):
84+
_ = rester.find_similar(fv)

0 commit comments

Comments
 (0)