1+
2+ import os
3+
4+ import numpy as np
5+ import pytest
6+
7+ from emmet .core .similarity import SimilarityScorer , SimilarityEntry
8+ from pymatgen .core import Structure
9+
10+ from mp_api .client .core import MPRestError
11+ from mp_api .client .routes .materials .similarity import SimilarityRester
12+
13+ from core_function import client_search_testing
14+
15+ @pytest .fixture (scope = "module" )
16+ def test_struct ():
17+ poscar = """Al2
18+ 1.0
19+ 3.3335729972004917 0.0000000000000000 1.9246389981090721
20+ 1.1111909992801432 3.1429239987499362 1.9246389992542632
21+ 0.0000000000000000 0.0000000000000000 3.8492780000000000
22+ Al
23+ 2
24+ direct
25+ 0.875 0.875 0.875 Al
26+ 0.125 0.125 0.125 Al
27+ """
28+ return Structure .from_str (poscar ,fmt = "poscar" )
29+
30+ @pytest .mark .skipif (os .getenv ("MP_API_KEY" ) is None , reason = "No API key found." )
31+ def test_client ():
32+ client_search_testing (
33+ search_method = SimilarityRester ().search ,
34+ excluded_params = [
35+ "num_chunks" ,
36+ "chunk_size" ,
37+ "all_fields" ,
38+ "fields" ,
39+ ],
40+ alt_name_dict = {
41+ "material_ids" : "material_id" ,
42+ },
43+ custom_field_tests = {
44+ "material_ids" : ["mp-149" ,"mp-13" ],
45+ "material_ids" : "mp-149"
46+ },
47+ sub_doc_fields = [],
48+ )
49+
50+ @pytest .mark .skipif (os .getenv ("MP_API_KEY" ) is None , reason = "No API key found." )
51+ def test_similarity_vector_search (test_struct ):
52+
53+ rester = SimilarityRester ()
54+ fv = rester .fingerprint_structure (test_struct )
55+ assert isinstance (fv ,np .ndarray )
56+ assert len (fv ) == 122
57+ assert isinstance (rester ._fingerprinter ,SimilarityScorer )
58+
59+
60+ get_top = 5
61+ sim_entries = rester .find_similar ("mp-149" ,top = get_top )
62+ assert all (
63+ isinstance (entry ,SimilarityEntry ) for entry in sim_entries
64+ )
65+ assert len (sim_entries ) == get_top
66+
67+ sim_dict_entries = SimilarityRester (use_document_model = False ).find_similar ("mp-149" ,top = get_top )
68+ assert all (
69+ isinstance (entry ,dict ) and all (
70+ k in entry for k in SimilarityEntry .model_fields
71+ )
72+ for entry in sim_dict_entries
73+ )
74+
75+ with pytest .raises (MPRestError ,match = "No similarity data available for" ):
76+ _ = rester .find_similar ("mp-0" )
77+
78+ assert all (
79+ isinstance (entry ,SimilarityEntry ) and isinstance (entry .dissimilarity ,float )
80+ for entry in rester .find_similar (test_struct , top = 2 ,)
81+ )
82+
83+ with pytest .raises (MPRestError ,match = "Please submit a pymatgen Structure or MP ID" ):
84+ _ = rester .find_similar (fv )
0 commit comments