Debugging

This commit is contained in:
araison 2022-12-17 18:10:04 +01:00
parent e9ef1cca9a
commit 9168cdb56b
1 changed files with 34 additions and 14 deletions

View File

@ -2,12 +2,16 @@ import traceback
import torch import torch
import torch.nn as nn import torch.nn as nn
from from_captum import CaptumWrapper
from from_graphxai import GraphXAIWrapper
from torch_geometric.data import Batch, Data from torch_geometric.data import Batch, Data
from torch_geometric.explain import Explainer from torch_geometric.explain import Explainer
from torch_geometric.nn import GATConv, GCNConv, GINConv, global_mean_pool from torch_geometric.nn import GATConv, GCNConv, GINConv, global_mean_pool
from from_captum import CaptumWrapper from explaining_framework.explaining_framework.metric.accuracy import Accuracy
from from_graphxai import GraphXAIWrapper from explaining_framework.explaining_framework.metric.fidelity import Fidelity
from explaining_framework.explaining_framework.metric.robust import Attack
from explaining_framework.explaining_framework.metric.sparsity import Sparsity
__all__captum = [ __all__captum = [
"LRP", "LRP",
@ -27,16 +31,16 @@ __all__captum = [
__all__graphxai = [ __all__graphxai = [
"CAM", "CAM",
"GradCAM", # "GradCAM",
"GNN_LRP", # "GNN_LRP",
"GradExplainer", # "GradExplainer",
"GuidedBackPropagation", # "GuidedBackPropagation",
"IntegratedGradients", # "IntegratedGradients",
"PGExplainer", # "PGExplainer",
"PGMExplainer", # "PGMExplainer",
"RandomExplainer", # "RandomExplainer",
"SubgraphX", # "SubgraphX",
"GraphMASK", # "GraphMASK",
] ]
@ -78,7 +82,7 @@ for epoch in range(1, 2):
target = torch.LongTensor([[0]]) target = torch.LongTensor([[0]])
for kind in ["node"]: for kind in ["node"]:
for name in __all__captum + __all__graphxai: for name in __all__graphxai:
if name in __all__captum: if name in __all__captum:
explaining_algorithm = CaptumWrapper(name) explaining_algorithm = CaptumWrapper(name)
elif name in __all__graphxai: elif name in __all__graphxai:
@ -101,6 +105,7 @@ for kind in ["node"]:
task_level=kind, task_level=kind,
return_type="raw", return_type="raw",
), ),
threshold_config=dict(threshold_type="hard", value=0.5),
) )
explanation = explainer( explanation = explainer(
x=batch.x, x=batch.x,
@ -108,7 +113,22 @@ for kind in ["node"]:
index=int(target), index=int(target),
target=batch.y, target=batch.y,
) )
print(explanation.__dict__) explanation_threshold = explanation._apply_mask(
node_mask=explanation.node_mask, edge_mask=explanation.edge_mask
)
for f_name in ["precision_score",
"precision_score",
"jaccard_score",
"roc_auc_score",
"f1_score",
"accuracy_score"]:
acc = Accuracy(f_name)
gt = torch.ones_like(x)/2
out = acc.forward(mask=explanation_threshold.node_mask,target=gt)
print(out)
except Exception as e: except Exception as e:
print(str(e)) print(str(e))
pass pass