Skip to content

Commit edc5415

Browse files
authored
Merge pull request #292 from interpretml/gaugup/CorrectVariableName
Change num_ouput_nodes to num_output_nodes in dice_tensorflow1.py
2 parents 06245b5 + 11fafe1 commit edc5415

1 file changed

Lines changed: 4 additions & 4 deletions

File tree

dice_ml/explainer_interfaces/dice_tensorflow1.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -225,7 +225,7 @@ def do_cf_initializations(self, total_CFs, algorithm, features_to_vary):
225225
def predict_fn(self, input_instance):
226226
"""prediction function"""
227227
temp_preds = self.dice_sess.run(self.output_tensor, feed_dict={self.input_tensor: input_instance})
228-
return np.array([preds[(self.num_ouput_nodes-1):] for preds in temp_preds])
228+
return np.array([preds[(self.num_output_nodes-1):] for preds in temp_preds])
229229

230230
def predict_fn_for_sparsity(self, input_instance):
231231
"""prediction function for sparsity correction"""
@@ -239,22 +239,22 @@ def compute_yloss(self, method):
239239
if method == "l2_loss":
240240
temp_loss = tf.square(tf.subtract(
241241
self.model.get_output(self.cfs_frozen[i]), self.target_cf))
242-
temp_loss = temp_loss[:, (self.num_ouput_nodes-1):][0][0]
242+
temp_loss = temp_loss[:, (self.num_output_nodes-1):][0][0]
243243
elif method == "log_loss":
244244
temp_logits = tf.log(
245245
tf.divide(
246246
tf.abs(tf.subtract(self.model.get_output(self.cfs_frozen[i]), 0.000001)),
247247
tf.subtract(1.0, tf.abs(tf.subtract(self.model.get_output(
248248
self.cfs_frozen[i]), 0.000001)))))
249-
temp_logits = temp_logits[:, (self.num_ouput_nodes-1):]
249+
temp_logits = temp_logits[:, (self.num_output_nodes-1):]
250250
temp_loss = tf.nn.sigmoid_cross_entropy_with_logits(
251251
logits=temp_logits, labels=self.target_cf)[0][0]
252252
elif method == "hinge_loss":
253253
temp_logits = tf.log(
254254
tf.divide(
255255
tf.abs(tf.subtract(self.model.get_output(self.cfs_frozen[i]), 0.000001)),
256256
tf.subtract(1.0, tf.abs(tf.subtract(self.model.get_output(self.cfs_frozen[i]), 0.000001)))))
257-
temp_logits = temp_logits[:, (self.num_ouput_nodes-1):]
257+
temp_logits = temp_logits[:, (self.num_output_nodes-1):]
258258
temp_loss = tf.losses.hinge_loss(
259259
logits=temp_logits, labels=self.target_cf)
260260

0 commit comments

Comments
 (0)