Skip to content

Commit c57e4ce

Browse files
committed
fixed bug in counterfactuals
1 parent dc90cd1 commit c57e4ce

File tree

2 files changed

+40
-23
lines changed

2 files changed

+40
-23
lines changed

XAI_main.ipynb

Lines changed: 36 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -4861,30 +4861,27 @@
48614861
},
48624862
{
48634863
"cell_type": "code",
4864-
"execution_count": 20,
4864+
"execution_count": 21,
48654865
"metadata": {},
48664866
"outputs": [
48674867
{
4868-
"ename": "NameError",
4869-
"evalue": "name 'se' is not defined",
4870-
"output_type": "error",
4871-
"traceback": [
4872-
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
4873-
"\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
4874-
"\u001b[0;32m<ipython-input-20-8a2f83a951d2>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0;32mdel\u001b[0m \u001b[0mse\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
4875-
"\u001b[0;31mNameError\u001b[0m: name 'se' is not defined"
4868+
"name": "stdout",
4869+
"output_type": "stream",
4870+
"text": [
4871+
"WARNING:tensorflow:From /home/ambreesh/Documents/amb_venv/lib/python3.6/site-packages/shap/explainers/tf_utils.py:28: The name tf.keras.backend.get_session is deprecated. Please use tf.compat.v1.keras.backend.get_session instead.\n",
4872+
"\n"
4873+
]
4874+
},
4875+
{
4876+
"name": "stderr",
4877+
"output_type": "stream",
4878+
"text": [
4879+
"keras is no longer supported, please use tf.keras instead.\n",
4880+
"WARNING:tensorflow:From /home/ambreesh/Documents/amb_venv/lib/python3.6/site-packages/shap/explainers/tf_utils.py:28: The name tf.keras.backend.get_session is deprecated. Please use tf.compat.v1.keras.backend.get_session instead.\n",
4881+
"\n"
48764882
]
48774883
}
48784884
],
4879-
"source": [
4880-
"del se"
4881-
]
4882-
},
4883-
{
4884-
"cell_type": "code",
4885-
"execution_count": null,
4886-
"metadata": {},
4887-
"outputs": [],
48884885
"source": [
48894886
"se = SHAP_Explainer(blend_alpha=0.45)\n",
48904887
"shap_out = se.explain(model, X_train[:1000], X_test/255.)\n",
@@ -4907,17 +4904,35 @@
49074904
},
49084905
{
49094906
"cell_type": "code",
4910-
"execution_count": null,
4907+
"execution_count": 22,
49114908
"metadata": {},
4912-
"outputs": [],
4909+
"outputs": [
4910+
{
4911+
"name": "stdout",
4912+
"output_type": "stream",
4913+
"text": [
4914+
"dict_keys(['Gradient_Inputs', 'Saliency', 'Integrated_Gradients', 'e-LRP', 'Occlusion', 'LIME', 'SHAP'])\n"
4915+
]
4916+
},
4917+
{
4918+
"data": {
4919+
"text/plain": [
4920+
"True"
4921+
]
4922+
},
4923+
"execution_count": 22,
4924+
"metadata": {},
4925+
"output_type": "execute_result"
4926+
}
4927+
],
49134928
"source": [
49144929
"print(xai_results.keys())\n",
49154930
"dump_pickle(xai_results, \"results/xai_results_%s.pkl\"%(DATASET))"
49164931
]
49174932
},
49184933
{
49194934
"cell_type": "code",
4920-
"execution_count": null,
4935+
"execution_count": 23,
49214936
"metadata": {},
49224937
"outputs": [],
49234938
"source": [

_counterfactual.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,12 +84,14 @@ def explain(
8484
if not isinstance(anomalous_samples, np.ndarray): anomalous_samples = np.array(anomalous_samples)
8585

8686
# Reconstructions and error calculations can be reduced to being done once
87-
normal_reconstructions = self.model.predict(normal_samples)[0]
87+
normal_reconstructions = self.model.predict(normal_samples)
8888
normal_mse = self.get_mse(normal_samples, normal_reconstructions)
8989

90-
anomalous_reconstructions = self.model.predict(anomalous_samples)[0]
90+
anomalous_reconstructions = self.model.predict(anomalous_samples)
9191
anomalous_mse = self.get_mse(anomalous_samples, anomalous_reconstructions)
9292

93+
if self.debug: print("NORMAL:", normal_samples.shape, normal_reconstructions.shape, normal_mse.shape)
94+
if self.debug: print("ANOMALY:", anomalous_samples.shape, anomalous_reconstructions.shape, anomalous_mse.shape)
9395
if self.debug: print("Models loaded. Starting analysis")
9496

9597
# Run the counterfactual part

0 commit comments

Comments
 (0)