@@ -256,6 +256,22 @@ def test_desired_class(
256256 assert all (ans .cf_examples_list [0 ].final_cfs_df_sparse [exp .data_interface .outcome_name ].values ==
257257 [desired_class ] * 2 )
258258
259+ exp .serialize_explainer (method + '.pkl' )
260+ new_exp = ExplainerBase .deserialize_explainer (method + '.pkl' )
261+
262+ ans = new_exp .generate_counterfactuals (query_instances = sample_custom_query_2 ,
263+ features_to_vary = 'all' ,
264+ total_CFs = 2 , desired_class = desired_class ,
265+ proximity_weight = 0.2 , sparsity_weight = 0.2 ,
266+ diversity_weight = 5.0 ,
267+ categorical_penalty = 0.1 ,
268+ permitted_range = None )
269+ if method != 'kdtree' :
270+ assert all (ans .cf_examples_list [0 ].final_cfs_df [new_exp .data_interface .outcome_name ].values == [desired_class ] * 2 )
271+ else :
272+ assert all (ans .cf_examples_list [0 ].final_cfs_df_sparse [new_exp .data_interface .outcome_name ].values ==
273+ [desired_class ] * 2 )
274+
259275 @pytest .mark .parametrize (("desired_class" , "total_CFs" , "permitted_range" ),
260276 [(1 , 1 , {'Numerical' : [10 , 150 ]})])
261277 def test_permitted_range (
@@ -349,6 +365,29 @@ def test_desired_class(
349365 [desired_class ] * total_CFs )
350366 assert all (i == desired_class for i in exp .cfs_preds )
351367
368+ exp .serialize_explainer (method + '.pkl' )
369+ new_exp = ExplainerBase .deserialize_explainer (method + '.pkl' )
370+
371+ if method != 'genetic' :
372+ ans = new_exp .generate_counterfactuals (
373+ query_instances = sample_custom_query_2 ,
374+ total_CFs = total_CFs , desired_class = desired_class )
375+ else :
376+ ans = new_exp .generate_counterfactuals (
377+ query_instances = sample_custom_query_2 ,
378+ total_CFs = total_CFs , desired_class = desired_class ,
379+ initialization = genetic_initialization )
380+
381+ assert ans is not None
382+ if method != 'kdtree' :
383+ assert all (
384+ ans .cf_examples_list [0 ].final_cfs_df [new_exp .data_interface .outcome_name ].values == [desired_class ] * total_CFs )
385+ else :
386+ assert all (
387+ ans .cf_examples_list [0 ].final_cfs_df_sparse [new_exp .data_interface .outcome_name ].values ==
388+ [desired_class ] * total_CFs )
389+ assert all (i == desired_class for i in new_exp .cfs_preds )
390+
352391 # When no elements in the desired_class are present in the training data
353392 @pytest .mark .parametrize (("desired_class" , "total_CFs" ), [(100 , 3 ), ('opposite' , 3 )])
354393 def test_unsupported_multiclass (
@@ -422,6 +461,16 @@ def test_numeric_categories(self, desired_range, method, create_housing_data):
422461
423462 assert cf_explanation is not None
424463
464+ exp .serialize_explainer ("explainer.pkl" )
465+ new_exp = ExplainerBase .deserialize_explainer ("explainer.pkl" )
466+
467+ cf_explanation = new_exp .generate_counterfactuals (
468+ query_instances = x_test .iloc [0 :1 ],
469+ total_CFs = 10 ,
470+ desired_range = desired_range )
471+
472+ assert cf_explanation is not None
473+
425474
426475class TestExplainerBase :
427476
0 commit comments