@@ -256,6 +256,22 @@ def test_desired_class(
256256 assert all (ans .cf_examples_list [0 ].final_cfs_df_sparse [exp .data_interface .outcome_name ].values ==
257257 [desired_class ] * 2 )
258258
259+ exp .serialize_explainer (method + '.pkl' )
260+ new_exp = ExplainerBase .deserialize_explainer (method + '.pkl' )
261+
262+ ans = new_exp .generate_counterfactuals (query_instances = sample_custom_query_2 ,
263+ features_to_vary = 'all' ,
264+ total_CFs = 2 , desired_class = desired_class ,
265+ proximity_weight = 0.2 , sparsity_weight = 0.2 ,
266+ diversity_weight = 5.0 ,
267+ categorical_penalty = 0.1 ,
268+ permitted_range = None )
269+ if method != 'kdtree' :
270+ assert all (ans .cf_examples_list [0 ].final_cfs_df [new_exp .data_interface .outcome_name ].values == [desired_class ] * 2 )
271+ else :
272+ assert all (ans .cf_examples_list [0 ].final_cfs_df_sparse [new_exp .data_interface .outcome_name ].values ==
273+ [desired_class ] * 2 )
274+
259275 @pytest .mark .parametrize (("desired_class" , "total_CFs" , "permitted_range" ),
260276 [(1 , 1 , {'Numerical' : [10 , 150 ]})])
261277 def test_permitted_range (
@@ -349,6 +365,30 @@ def test_desired_class(
349365 [desired_class ] * total_CFs )
350366 assert all (i == desired_class for i in exp .cfs_preds )
351367
368+ exp .serialize_explainer (method + '.pkl' )
369+ new_exp = ExplainerBase .deserialize_explainer (method + '.pkl' )
370+
371+ if method != 'genetic' :
372+ ans = new_exp .generate_counterfactuals (
373+ query_instances = sample_custom_query_2 ,
374+ total_CFs = total_CFs , desired_class = desired_class )
375+ else :
376+ ans = new_exp .generate_counterfactuals (
377+ query_instances = sample_custom_query_2 ,
378+ total_CFs = total_CFs , desired_class = desired_class ,
379+ initialization = genetic_initialization )
380+
381+ assert ans is not None
382+ if method != 'kdtree' :
383+ assert all (
384+ ans .cf_examples_list [0 ].final_cfs_df [
385+ new_exp .data_interface .outcome_name ].values == [desired_class ] * total_CFs )
386+ else :
387+ assert all (
388+ ans .cf_examples_list [0 ].final_cfs_df_sparse [new_exp .data_interface .outcome_name ].values ==
389+ [desired_class ] * total_CFs )
390+ assert all (i == desired_class for i in new_exp .cfs_preds )
391+
352392 # When no elements in the desired_class are present in the training data
353393 @pytest .mark .parametrize (("desired_class" , "total_CFs" ), [(100 , 3 ), ('opposite' , 3 )])
354394 def test_unsupported_multiclass (
@@ -422,6 +462,16 @@ def test_numeric_categories(self, desired_range, method, create_housing_data):
422462
423463 assert cf_explanation is not None
424464
465+ exp .serialize_explainer ("explainer.pkl" )
466+ new_exp = ExplainerBase .deserialize_explainer ("explainer.pkl" )
467+
468+ cf_explanation = new_exp .generate_counterfactuals (
469+ query_instances = x_test .iloc [0 :1 ],
470+ total_CFs = 10 ,
471+ desired_range = desired_range )
472+
473+ assert cf_explanation is not None
474+
425475
426476class TestExplainerBase :
427477
0 commit comments