@@ -5,7 +5,8 @@ def __init__(self, **kwargs):
55 self .atom_obj = kwargs .get ('atom_obj' , None )
66 self .electric_charge_and_multiplicity = kwargs .get ('electric_charge_and_multiplicity' , None )
77 self .software_path = kwargs .get ('software_path' , None )
8- self .task_name = "omol"
8+ self .task_name = kwargs .get ('task_name' , "omol" )
9+ self .device_mode = kwargs .get ('device_mode' , "cpu" )
910 self .software_type = kwargs .get ('software_type' , None )
1011 print (f"ASE_FAIRCHEM: software_type = { self .software_type } " )
1112
@@ -17,8 +18,9 @@ def run(self): # fairchem.core: version 2.x.x
1718 except ImportError :
1819 raise ImportError ("FAIRChem.core modules not found" )
1920 # Load the prediction unit
20- predict_unit = load_predict_unit (path = self .software_path , device = "cpu" )
21-
21+ predict_unit = load_predict_unit (path = self .software_path , device = self .device_mode )
22+ print (f"ASE_FAIRCHEM: device_mode = { self .device_mode } " )
23+ print (f"ASE_FAIRCHEM: task_name = { self .task_name } " )
2224 # Set up the FAIRChem calculator
2325 fairchem_calc = FAIRChemCalculator (predict_unit = predict_unit , task_name = self .task_name )
2426 self .atom_obj .info = {"charge" : int (self .electric_charge_and_multiplicity [0 ]),
0 commit comments