diff --git a/explaining_framework/explainers/wrappers/test.py b/explaining_framework/explainers/wrappers/test.py index 313d357..ba2a778 100644 --- a/explaining_framework/explainers/wrappers/test.py +++ b/explaining_framework/explainers/wrappers/test.py @@ -2,12 +2,16 @@ import traceback import torch import torch.nn as nn +from from_captum import CaptumWrapper +from from_graphxai import GraphXAIWrapper from torch_geometric.data import Batch, Data from torch_geometric.explain import Explainer from torch_geometric.nn import GATConv, GCNConv, GINConv, global_mean_pool -from from_captum import CaptumWrapper -from from_graphxai import GraphXAIWrapper +from explaining_framework.explaining_framework.metric.accuracy import Accuracy +from explaining_framework.explaining_framework.metric.fidelity import Fidelity +from explaining_framework.explaining_framework.metric.robust import Attack +from explaining_framework.explaining_framework.metric.sparsity import Sparsity __all__captum = [ "LRP", @@ -27,16 +31,16 @@ __all__captum = [ __all__graphxai = [ "CAM", - "GradCAM", - "GNN_LRP", - "GradExplainer", - "GuidedBackPropagation", - "IntegratedGradients", - "PGExplainer", - "PGMExplainer", - "RandomExplainer", - "SubgraphX", - "GraphMASK", + # "GradCAM", + # "GNN_LRP", + # "GradExplainer", + # "GuidedBackPropagation", + # "IntegratedGradients", + # "PGExplainer", + # "PGMExplainer", + # "RandomExplainer", + # "SubgraphX", + # "GraphMASK", ] @@ -78,7 +82,7 @@ for epoch in range(1, 2): target = torch.LongTensor([[0]]) for kind in ["node"]: - for name in __all__captum + __all__graphxai: + for name in __all__graphxai: if name in __all__captum: explaining_algorithm = CaptumWrapper(name) elif name in __all__graphxai: @@ -101,6 +105,7 @@ for kind in ["node"]: task_level=kind, return_type="raw", ), + threshold_config=dict(threshold_type="hard", value=0.5), ) explanation = explainer( x=batch.x, @@ -108,7 +113,22 @@ for kind in ["node"]: index=int(target), target=batch.y, ) - print(explanation.__dict__) + explanation_threshold = explanation._apply_mask( + node_mask=explanation.node_mask, edge_mask=explanation.edge_mask + ) + + for f_name in ["precision_score", + "precision_score", + "jaccard_score", + "roc_auc_score", + "f1_score", + "accuracy_score"]: + acc = Accuracy(f_name) + gt = torch.ones_like(x)/2 + out = acc.forward(mask=explanation_threshold.node_mask,target=gt) + print(out) + + except Exception as e: print(str(e)) pass