Skip to content

Commit 509802d

Browse files
gaugupamit-sharma
andauthored
Reduce number inferences in dice random (#127)
* Redice number of inference calls DiceRandom Signed-off-by: gaugup <gaugup@microsoft.com> * Added comment Signed-off-by: gaugup <gaugup@microsoft.com> * corrected typo to model_predictions Co-authored-by: Amit Sharma <amit_sharma@live.com>
1 parent e69f7ca commit 509802d

1 file changed

Lines changed: 6 additions & 2 deletions

File tree

dice_ml/explainer_interfaces/dice_random.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,14 +63,18 @@ def _generate_counterfactuals(self, query_instance, total_CFs, desired_range, d
6363
else: # compute the new ranges based on user input
6464
self.feature_range, feature_ranges_orig = self.data_interface.get_features_range(permitted_range)
6565

66+
# Do predictions once on the query_instance and reuse across to reduce the number
67+
# inferences.
68+
model_predictions = self.predict_fn(query_instance)
69+
6670
# number of output nodes of ML model
6771
self.num_output_nodes = None
6872
if self.model.model_type == "classifier":
69-
self.num_output_nodes = self.predict_fn(query_instance).shape[1]
73+
self.num_output_nodes = model_predictions.shape[1]
7074

7175
# query_instance need no transformation for generating CFs using random sampling.
7276
# find the predicted value of query_instance
73-
test_pred = self.predict_fn(query_instance)[0]
77+
test_pred = model_predictions[0]
7478
if self.model.model_type == 'classifier':
7579
self.target_cf_class = self.infer_target_cfs_class(desired_class, test_pred, self.num_output_nodes)
7680
elif self.model.model_type == 'regressor':

0 commit comments

Comments
 (0)