33import itertools
44import os
55import warnings
6+ from collections import defaultdict
67from functools import cache , lru_cache
78from typing import TYPE_CHECKING
89
@@ -424,20 +425,15 @@ def get_task_ids_associated_with_material_id(
424425 if not tasks :
425426 return []
426427
427- calculations = (
428- tasks [0 ].calc_types # type: ignore
429- if self .use_document_model
430- else tasks [0 ]["calc_types" ] # type: ignore
431- )
428+ calculations = tasks [0 ]["calc_types" ]
432429
433430 if calc_types :
434431 return [
435432 task
436433 for task , calc_type in calculations .items ()
437434 if calc_type in calc_types
438435 ]
439- else :
440- return list (calculations .keys ())
436+ return list (calculations .keys ())
441437
442438 def get_structure_by_material_id (
443439 self , material_id : str , final : bool = True , conventional_unit_cell : bool = False
@@ -539,11 +535,7 @@ def get_material_id_references(self, material_id: str) -> list[str]:
539535 List of BibTeX references ([str])
540536 """
541537 docs = self .materials .provenance .search (material_ids = material_id )
542-
543- if not docs :
544- return []
545-
546- return docs [0 ].references if self .use_document_model else docs [0 ]["references" ] # type: ignore
538+ return docs [0 ]["references" ] if docs else []
547539
548540 def get_material_ids (
549541 self ,
@@ -558,17 +550,16 @@ def get_material_ids(
558550 Returns:
559551 List of all materials ids ([MPID])
560552 """
553+ inp_k = "formula"
561554 if isinstance (chemsys_formula , list ) or (
562555 isinstance (chemsys_formula , str ) and "-" in chemsys_formula
563556 ):
564- input_params = {"chemsys" : chemsys_formula }
565- else :
566- input_params = {"formula" : chemsys_formula }
557+ inp_k = "chemsys"
567558
568559 return sorted (
569- doc . material_id if self . use_document_model else doc ["material_id" ] # type: ignore
560+ doc ["material_id" ]
570561 for doc in self .materials .search (
571- ** input_params , # type: ignore
562+ ** { inp_k : chemsys_formula },
572563 all_fields = False ,
573564 fields = ["material_id" ],
574565 )
@@ -601,10 +592,8 @@ def get_structures(
601592 all_fields = False ,
602593 fields = ["structure" ],
603594 )
604- if not self .use_document_model :
605- return [doc ["structure" ] for doc in docs ] # type: ignore
606595
607- return [doc . structure for doc in docs ] # type: ignore
596+ return [doc [ " structure" ] for doc in docs ]
608597 else :
609598 structures = []
610599
@@ -613,12 +602,7 @@ def get_structures(
613602 all_fields = False ,
614603 fields = ["initial_structures" ],
615604 ):
616- initial_structures = (
617- doc .initial_structures # type: ignore
618- if self .use_document_model
619- else doc ["initial_structures" ] # type: ignore
620- )
621- structures .extend (initial_structures )
605+ structures .extend (doc ["initial_structures" ])
622606
623607 return structures
624608
@@ -723,7 +707,7 @@ def get_entries(
723707 if additional_criteria :
724708 input_params = {** input_params , ** additional_criteria }
725709
726- entries = []
710+ entries : set [ ComputedStructureEntry ] = set ()
727711
728712 fields = (
729713 ["entries" , "thermo_type" ]
@@ -738,24 +722,17 @@ def get_entries(
738722 )
739723
740724 for doc in docs :
741- entry_list = (
742- doc .entries .values () # type: ignore
743- if self .use_document_model
744- else doc ["entries" ].values () # type: ignore
745- )
725+ entry_list = doc ["entries" ].values ()
746726 for entry in entry_list :
747- entry_dict : dict = entry .as_dict () if self . monty_decode else entry # type: ignore
727+ entry_dict : dict = entry .as_dict () if hasattr ( entry , "as_dict" ) else entry # type: ignore
748728 if not compatible_only :
749729 entry_dict ["correction" ] = 0.0
750730 entry_dict ["energy_adjustments" ] = []
751731
752732 if property_data :
753- for property in property_data :
754- entry_dict ["data" ][property ] = (
755- doc .model_dump ()[property ] # type: ignore
756- if self .use_document_model
757- else doc [property ] # type: ignore
758- )
733+ entry_dict ["data" ] = {
734+ property : doc [property ] for property in property_data
735+ }
759736
760737 if conventional_unit_cell :
761738 entry_struct = Structure .from_dict (entry_dict ["structure" ])
@@ -776,15 +753,10 @@ def get_entries(
776753 if "n_atoms" in correction :
777754 correction ["n_atoms" ] *= site_ratio
778755
779- entry = (
780- ComputedStructureEntry .from_dict (entry_dict )
781- if self .monty_decode
782- else entry_dict
783- )
756+ # Need to store object to permit de-duplication
757+ entries .add (ComputedStructureEntry .from_dict (entry_dict ))
784758
785- entries .append (entry )
786-
787- return entries
759+ return [e if self .monty_decode else e .as_dict () for e in entries ]
788760
789761 def get_pourbaix_entries (
790762 self ,
@@ -1315,9 +1287,7 @@ def get_wulff_shape(self, material_id: str):
13151287 if not doc :
13161288 return None
13171289
1318- surfaces : list = (
1319- doc [0 ].surfaces if self .use_document_model else doc [0 ]["surfaces" ] # type: ignore
1320- )
1290+ surfaces : list = doc [0 ]["surfaces" ]
13211291
13221292 lattice = (
13231293 SpacegroupAnalyzer (structure ).get_conventional_standard_structure ().lattice
@@ -1387,17 +1357,8 @@ def get_charge_density_from_material_id(
13871357 if len (results ) == 0 :
13881358 return None
13891359
1390- latest_doc = max ( # type: ignore
1391- results ,
1392- key = lambda x : (
1393- x .last_updated # type: ignore
1394- if self .use_document_model
1395- else x ["last_updated" ]
1396- ), # type: ignore
1397- )
1398- task_id = (
1399- latest_doc .task_id if self .use_document_model else latest_doc ["task_id" ]
1400- )
1360+ latest_doc = max (results , key = lambda x : x ["last_updated" ])
1361+ task_id = latest_doc ["task_id" ]
14011362 return self .get_charge_density_from_task_id (task_id , inc_task_doc )
14021363
14031364 def get_download_info (self , material_ids , calc_types = None , file_patterns = None ):
@@ -1419,20 +1380,17 @@ def get_download_info(self, material_ids, calc_types=None, file_patterns=None):
14191380 else []
14201381 )
14211382
1422- meta = {}
1383+ meta = defaultdict ( list )
14231384 for doc in self .materials .search ( # type: ignore
14241385 task_ids = material_ids ,
14251386 fields = ["calc_types" , "deprecated_tasks" , "material_id" ],
14261387 ):
1427- doc_dict : dict = doc .model_dump () if self .use_document_model else doc # type: ignore
1428- for task_id , calc_type in doc_dict ["calc_types" ].items ():
1388+ for task_id , calc_type in doc ["calc_types" ].items ():
14291389 if calc_types and calc_type not in calc_types :
14301390 continue
1431- mp_id = doc_dict ["material_id" ]
1432- if meta .get (mp_id ) is None :
1433- meta [mp_id ] = [{"task_id" : task_id , "calc_type" : calc_type }]
1434- else :
1435- meta [mp_id ].append ({"task_id" : task_id , "calc_type" : calc_type })
1391+ mp_id = doc ["material_id" ]
1392+ meta [mp_id ].append ({"task_id" : task_id , "calc_type" : calc_type })
1393+
14361394 if not meta :
14371395 raise ValueError (f"No tasks found for material id { material_ids } ." )
14381396
0 commit comments