# Source-viewer header: 147 lines, 4.3 KiB, Python.
import traceback
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
from explaining_framework.metric.accuracy import Accuracy
|
|
from explaining_framework.metric.fidelity import Fidelity
|
|
from explaining_framework.metric.robust import Attack
|
|
from explaining_framework.metric.sparsity import Sparsity
|
|
from torch_geometric.data import Batch, Data
|
|
from torch_geometric.explain import Explainer
|
|
from torch_geometric.nn import GATConv, GCNConv, GINConv, global_mean_pool
|
|
|
|
from from_captum import CaptumWrapper
|
|
from from_graphxai import GraphXAIWrapper
|
|
|
|
# Attribution methods that are dispatched to ``CaptumWrapper`` by name.
__all__captum = (
    "LRP DeepLift DeepLiftShap FeatureAblation FeaturePermutation "
    "GradientShap GuidedBackprop GuidedGradCam InputXGradient "
    "IntegratedGradients Lime Occlusion Saliency"
).split()
|
|
|
|
# Methods that are dispatched to ``GraphXAIWrapper`` by name. Only CAM is
# enabled at the moment; the other GraphXAI methods (GradCAM, GNN_LRP,
# GradExplainer, GuidedBackPropagation, IntegratedGradients, PGExplainer,
# PGMExplainer, RandomExplainer, SubgraphX, GraphMASK) are switched off.
# Add a name back to this list to enable it.
__all__graphxai = ["CAM"]
|
|
|
|
|
|
# Toy input: a 3-node path graph 0-1-2, each edge stored in both directions.
edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]], dtype=torch.long)
size_F = 4  # number of input features per node
# Output width of the model. The same value is passed to GraphXAIWrapper as
# ``in_channels`` (NOTE(review): confirm that is the meaning that wrapper expects).
size_O = in_channels = 6
x = torch.ones((3, size_F))
y = torch.tensor([1], dtype=torch.long)  # single label for the single graph
loss = nn.CrossEntropyLoss()

data = Data(x=x, edge_index=edge_index, y=y)
# ``from_data_list`` is a classmethod, so call it on the class directly
# instead of building a throwaway empty ``Batch`` first.
batch = Batch.from_data_list([data])
|
|
|
|
|
|
class Model(torch.nn.Module):
    """One GCN layer followed by mean pooling into graph-level embeddings.

    Args:
        dim_in: number of input features per node.
        dim_out: number of output channels of the GCN layer, which is also
            the width of each pooled graph embedding.
    """

    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.dim_in = dim_in
        self.dim_out = dim_out
        self.conv = GCNConv(dim_in, dim_out)

    def forward(self, x, edge_index, batch=None):
        """Apply the GCN layer and mean-pool node embeddings per graph.

        Args:
            x: node feature matrix of shape ``(num_nodes, dim_in)``.
            edge_index: graph connectivity in COO format.
            batch: optional graph-assignment vector of length ``num_nodes``.
                When omitted, every node is assigned to graph 0, so the
                output has a single row.

        Returns:
            Pooled embeddings of shape ``(num_graphs, dim_out)``.
        """
        x = self.conv(x, edge_index)
        if batch is None:
            # global_mean_pool expects one graph id per node (length
            # num_nodes), created on the same device as the features.
            batch = torch.zeros(x.size(0), dtype=torch.long, device=x.device)
        return global_mean_pool(x, batch)
|
|
|
|
|
|
model = Model(size_F, size_O)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

# Graph index later handed to the explainer (``index=int(target)``). It does
# not depend on training, so define it once here; it then exists even when
# NUM_EPOCHS is 0.
target = torch.LongTensor([[0]])

NUM_EPOCHS = 1
for _ in range(NUM_EPOCHS):
    model.train()
    optimizer.zero_grad()
    out = model(batch.x, batch.edge_index)
    # NOTE(review): CrossEntropyLoss receives an all-ones (1, size_O) tensor
    # as a soft-label target. It is not a probability distribution and does
    # not use ``batch.y``; it presumably exists only to produce a gradient.
    # Confirm intent.
    lossee = loss(out, torch.ones(1, size_O))
    lossee.backward()
    optimizer.step()
|
|
|
|
# Robustness attacks applied to every explanation produced below.
ATTACK_NAMES = [
    "gaussian_noise",
    "add_edge",
    "remove_edge",
    "remove_node",
    "pgd",
    "fgsm",
]


def _build_algorithm(name):
    """Return the explaining-algorithm wrapper registered under ``name``.

    Captum names are checked first, so a name listed in both registries is
    handled by ``CaptumWrapper``.

    Raises:
        ValueError: if ``name`` is in neither ``__all__captum`` nor
            ``__all__graphxai``.
    """
    if name in __all__captum:
        return CaptumWrapper(name)
    if name in __all__graphxai:
        return GraphXAIWrapper(
            name, in_channels=in_channels, criterion="cross-entropy"
        )
    raise ValueError(f"Unknown explaining algorithm: {name!r}")


# Smoke test. For each enabled algorithm, explain graph ``int(target)``
# with the trained model, then run every attack on the explanation.
# Failures are reported with a traceback and the loop moves on: a broken
# algorithm skips only its own attacks, and a broken attack skips only
# itself.
# NOTE(review): only ``__all__graphxai`` is iterated, so Captum methods run
# only when a name appears in both lists. Iterate
# ``__all__captum + __all__graphxai`` to exercise Captum as well.
for kind in ["graph"]:
    for name in __all__graphxai:
        print(name)
        try:
            explaining_algorithm = _build_algorithm(name)
            explainer = Explainer(
                model=model,
                algorithm=explaining_algorithm,
                explainer_config=dict(
                    explanation_type="phenomenon",
                    node_mask_type="object",
                    edge_mask_type="object",
                ),
                model_config=dict(
                    mode="regression",
                    task_level=kind,
                    return_type="raw",
                ),
                threshold_config=dict(threshold_type="hard", value=0.5),
            )
            explanation = explainer(
                x=batch.x,
                edge_index=batch.edge_index,
                index=int(target),
                target=batch.y,
            )
            print(explanation.__dict__)
            # Relies on the private ``Explanation._apply_masks``; an all-True
            # node mask keeps every node.
            explanation_threshold = explanation._apply_masks(
                node_mask=torch.ones_like(explanation.node_mask).bool()
            )
            print(explanation_threshold.__dict__)
        except Exception:
            traceback.print_exc()
            continue

        # NOTE(review): the attacks receive ``explanation``, not
        # ``explanation_threshold``. Confirm which one is intended.
        for f_name in ATTACK_NAMES:
            print(f_name)
            try:
                attack = Attack(name=f_name, model=model, loss=loss)
                out = attack.forward(explanation)
            except Exception:
                traceback.print_exc()
                continue
            print(out)
|