Debugging
This commit is contained in:
parent
e9ef1cca9a
commit
9168cdb56b
|
@ -2,12 +2,16 @@ import traceback
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
from from_captum import CaptumWrapper
|
||||||
|
from from_graphxai import GraphXAIWrapper
|
||||||
from torch_geometric.data import Batch, Data
|
from torch_geometric.data import Batch, Data
|
||||||
from torch_geometric.explain import Explainer
|
from torch_geometric.explain import Explainer
|
||||||
from torch_geometric.nn import GATConv, GCNConv, GINConv, global_mean_pool
|
from torch_geometric.nn import GATConv, GCNConv, GINConv, global_mean_pool
|
||||||
|
|
||||||
from from_captum import CaptumWrapper
|
from explaining_framework.explaining_framework.metric.accuracy import Accuracy
|
||||||
from from_graphxai import GraphXAIWrapper
|
from explaining_framework.explaining_framework.metric.fidelity import Fidelity
|
||||||
|
from explaining_framework.explaining_framework.metric.robust import Attack
|
||||||
|
from explaining_framework.explaining_framework.metric.sparsity import Sparsity
|
||||||
|
|
||||||
__all__captum = [
|
__all__captum = [
|
||||||
"LRP",
|
"LRP",
|
||||||
|
@ -27,16 +31,16 @@ __all__captum = [
|
||||||
|
|
||||||
__all__graphxai = [
|
__all__graphxai = [
|
||||||
"CAM",
|
"CAM",
|
||||||
"GradCAM",
|
# "GradCAM",
|
||||||
"GNN_LRP",
|
# "GNN_LRP",
|
||||||
"GradExplainer",
|
# "GradExplainer",
|
||||||
"GuidedBackPropagation",
|
# "GuidedBackPropagation",
|
||||||
"IntegratedGradients",
|
# "IntegratedGradients",
|
||||||
"PGExplainer",
|
# "PGExplainer",
|
||||||
"PGMExplainer",
|
# "PGMExplainer",
|
||||||
"RandomExplainer",
|
# "RandomExplainer",
|
||||||
"SubgraphX",
|
# "SubgraphX",
|
||||||
"GraphMASK",
|
# "GraphMASK",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@ -78,7 +82,7 @@ for epoch in range(1, 2):
|
||||||
target = torch.LongTensor([[0]])
|
target = torch.LongTensor([[0]])
|
||||||
|
|
||||||
for kind in ["node"]:
|
for kind in ["node"]:
|
||||||
for name in __all__captum + __all__graphxai:
|
for name in __all__graphxai:
|
||||||
if name in __all__captum:
|
if name in __all__captum:
|
||||||
explaining_algorithm = CaptumWrapper(name)
|
explaining_algorithm = CaptumWrapper(name)
|
||||||
elif name in __all__graphxai:
|
elif name in __all__graphxai:
|
||||||
|
@ -101,6 +105,7 @@ for kind in ["node"]:
|
||||||
task_level=kind,
|
task_level=kind,
|
||||||
return_type="raw",
|
return_type="raw",
|
||||||
),
|
),
|
||||||
|
threshold_config=dict(threshold_type="hard", value=0.5),
|
||||||
)
|
)
|
||||||
explanation = explainer(
|
explanation = explainer(
|
||||||
x=batch.x,
|
x=batch.x,
|
||||||
|
@ -108,7 +113,22 @@ for kind in ["node"]:
|
||||||
index=int(target),
|
index=int(target),
|
||||||
target=batch.y,
|
target=batch.y,
|
||||||
)
|
)
|
||||||
print(explanation.__dict__)
|
explanation_threshold = explanation._apply_mask(
|
||||||
|
node_mask=explanation.node_mask, edge_mask=explanation.edge_mask
|
||||||
|
)
|
||||||
|
|
||||||
|
for f_name in ["precision_score",
|
||||||
|
"precision_score",
|
||||||
|
"jaccard_score",
|
||||||
|
"roc_auc_score",
|
||||||
|
"f1_score",
|
||||||
|
"accuracy_score"]:
|
||||||
|
acc = Accuracy(f_name)
|
||||||
|
gt = torch.ones_like(x)/2
|
||||||
|
out = acc.forward(mask=explanation_threshold.node_mask,target=gt)
|
||||||
|
print(out)
|
||||||
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(str(e))
|
print(str(e))
|
||||||
pass
|
pass
|
||||||
|
|
Loading…
Reference in New Issue