New fixes and new features
This commit is contained in:
parent
6cf1d64d3a
commit
ea0a5dd86e
|
@ -1,3 +1,5 @@
|
||||||
|
import logging
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from captum.attr import (LRP, DeepLift, DeepLiftShap, FeatureAblation,
|
from captum.attr import (LRP, DeepLift, DeepLiftShap, FeatureAblation,
|
||||||
FeaturePermutation, GradientShap, GuidedBackprop,
|
FeaturePermutation, GradientShap, GuidedBackprop,
|
||||||
|
@ -96,7 +98,7 @@ class CaptumWrapper(ExplainerAlgorithm):
|
||||||
"Occlusion",
|
"Occlusion",
|
||||||
"Saliency",
|
"Saliency",
|
||||||
]:
|
]:
|
||||||
if task_level not in [ModelTaskLevel.node, ModelTaskLevel.graph]:
|
if task_level not in [ModelTaskLevel.graph]:
|
||||||
logging.error(f"Task level '{task_level.value}' not supported")
|
logging.error(f"Task level '{task_level.value}' not supported")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
@ -209,11 +211,12 @@ class CaptumWrapper(ExplainerAlgorithm):
|
||||||
model: torch.nn.Module,
|
model: torch.nn.Module,
|
||||||
x: Tensor,
|
x: Tensor,
|
||||||
edge_index: Tensor,
|
edge_index: Tensor,
|
||||||
|
index: int,
|
||||||
target: int,
|
target: int,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
mask_type = self._get_mask_type()
|
mask_type = self._get_mask_type()
|
||||||
converted_model = to_captum_model(model, mask_type=mask_type, output_idx=target)
|
converted_model = to_captum_model(model, mask_type=mask_type, output_idx=index)
|
||||||
self.captum_method = self._load_captum_method(converted_model)
|
self.captum_method = self._load_captum_method(converted_model)
|
||||||
inputs, additional_forward_args = to_captum_input(
|
inputs, additional_forward_args = to_captum_input(
|
||||||
x, edge_index, mask_type=mask_type
|
x, edge_index, mask_type=mask_type
|
||||||
|
|
|
@ -225,17 +225,6 @@ class GraphXAIWrapper(ExplainerAlgorithm):
|
||||||
|
|
||||||
return node_mask, edge_mask, node_feat_mask, edge_feat_mask
|
return node_mask, edge_mask, node_feat_mask, edge_feat_mask
|
||||||
|
|
||||||
def _parse_method_args(self, method, **kwargs):
|
|
||||||
signature = inspect.signature(method)
|
|
||||||
args = tuple(
|
|
||||||
[
|
|
||||||
kwargs[k.name]
|
|
||||||
for k in signature.parameters.values()
|
|
||||||
if k.name in kwargs.keys()
|
|
||||||
]
|
|
||||||
)
|
|
||||||
return args
|
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
model: torch.nn.Module,
|
model: torch.nn.Module,
|
||||||
|
@ -243,47 +232,27 @@ class GraphXAIWrapper(ExplainerAlgorithm):
|
||||||
edge_index: Tensor,
|
edge_index: Tensor,
|
||||||
target: Tensor,
|
target: Tensor,
|
||||||
index: Optional[Union[int, Tensor]] = None,
|
index: Optional[Union[int, Tensor]] = None,
|
||||||
target_index: Optional[int] = None,
|
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
mask_type = self._get_mask_type()
|
mask_type = self._get_mask_type()
|
||||||
self.graphxai_method = self._load_graphxai_method(model)
|
self.graphxai_method = self._load_graphxai_method(model)
|
||||||
# IF CRITERION = MSE:
|
|
||||||
# if (
|
|
||||||
# self.name in ["IntegratedGradients", "GradExplainer"]
|
|
||||||
# and "label" in kwargs.keys()
|
|
||||||
# ):
|
|
||||||
# kwargs["label"] = kwargs["label"].float()
|
|
||||||
if (
|
|
||||||
self.name in ["PGMExplainer", "RandomExplainer"]
|
|
||||||
and "label" in kwargs.keys()
|
|
||||||
):
|
|
||||||
kwargs.pop("label")
|
|
||||||
|
|
||||||
if self.model_config.task_level == ModelTaskLevel.node:
|
if self.model_config.task_level == ModelTaskLevel.node:
|
||||||
args = self._parse_method_args(
|
attr = self.graphxai_method.get_explanation_node(
|
||||||
self.graphxai_method.get_explanation_node,
|
|
||||||
x=x,
|
x=x,
|
||||||
edge_index=edge_index,
|
edge_index=edge_index,
|
||||||
node_idx=target,
|
label=target,
|
||||||
|
node_idx=index,
|
||||||
|
y=target,
|
||||||
)
|
)
|
||||||
|
|
||||||
attr = self.graphxai_method.get_explanation_node(*args, **kwargs)
|
|
||||||
elif self.model_config.task_level == ModelTaskLevel.graph:
|
elif self.model_config.task_level == ModelTaskLevel.graph:
|
||||||
args = self._parse_method_args(
|
attr = self.graphxai_method.get_explanation_graph(
|
||||||
self.graphxai_method.get_explanation_graph,
|
|
||||||
x=x,
|
x=x,
|
||||||
edge_index=edge_index,
|
edge_index=edge_index,
|
||||||
|
label=target,
|
||||||
|
y=target,
|
||||||
)
|
)
|
||||||
|
|
||||||
attr = self.graphxai_method.get_explanation_graph(*args, **kwargs)
|
|
||||||
elif self.model_config.task_level == ModelTaskLevel.edge:
|
elif self.model_config.task_level == ModelTaskLevel.edge:
|
||||||
args = self._parse_method_args(
|
|
||||||
self.graphxai_method.get_explanation_link,
|
|
||||||
x=x,
|
|
||||||
edge_index=edge_index,
|
|
||||||
)
|
|
||||||
|
|
||||||
attr = self.graphxai_method.get_explanation_link(*args, **kwargs)
|
attr = self.graphxai_method.get_explanation_link(*args, **kwargs)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"{self.model_config.task_level} is not supported yet")
|
raise ValueError(f"{self.model_config.task_level} is not supported yet")
|
||||||
|
|
|
@ -0,0 +1,114 @@
|
||||||
|
import traceback
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from torch_geometric.data import Batch, Data
|
||||||
|
from torch_geometric.explain import Explainer
|
||||||
|
from torch_geometric.nn import GATConv, GCNConv, GINConv, global_mean_pool
|
||||||
|
|
||||||
|
from from_captum import CaptumWrapper
|
||||||
|
from from_graphxai import GraphXAIWrapper
|
||||||
|
|
||||||
|
__all__captum = [
|
||||||
|
"LRP",
|
||||||
|
"DeepLift",
|
||||||
|
"DeepLiftShap",
|
||||||
|
"FeatureAblation",
|
||||||
|
"FeaturePermutation",
|
||||||
|
"GradientShap",
|
||||||
|
"GuidedBackprop",
|
||||||
|
"GuidedGradCam",
|
||||||
|
"InputXGradient",
|
||||||
|
"IntegratedGradients",
|
||||||
|
"Lime",
|
||||||
|
"Occlusion",
|
||||||
|
"Saliency",
|
||||||
|
]
|
||||||
|
|
||||||
|
__all__graphxai = [
|
||||||
|
"CAM",
|
||||||
|
"GradCAM",
|
||||||
|
"GNN_LRP",
|
||||||
|
"GradExplainer",
|
||||||
|
"GuidedBackPropagation",
|
||||||
|
"IntegratedGradients",
|
||||||
|
"PGExplainer",
|
||||||
|
"PGMExplainer",
|
||||||
|
"RandomExplainer",
|
||||||
|
"SubgraphX",
|
||||||
|
"GraphMASK",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]], dtype=torch.long)
|
||||||
|
size_F = 4
|
||||||
|
size_O = in_channels = 6
|
||||||
|
x = torch.ones((3, size_F))
|
||||||
|
y = torch.tensor([1], dtype=torch.long)
|
||||||
|
loss = nn.CrossEntropyLoss()
|
||||||
|
|
||||||
|
data = Data(x=x, edge_index=edge_index, y=y)
|
||||||
|
batch = Batch().from_data_list([data])
|
||||||
|
|
||||||
|
|
||||||
|
class Model(torch.nn.Module):
|
||||||
|
def __init__(self, dim_in, dim_out):
|
||||||
|
super().__init__()
|
||||||
|
self.dim_in = dim_in
|
||||||
|
self.dim_out = dim_out
|
||||||
|
self.conv = GCNConv(dim_in, dim_out)
|
||||||
|
|
||||||
|
def forward(self, x, edge_index):
|
||||||
|
x = self.conv(x, edge_index)
|
||||||
|
x = global_mean_pool(x, torch.LongTensor([0]))
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
model = Model(size_F, size_O)
|
||||||
|
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
|
||||||
|
|
||||||
|
for epoch in range(1, 2):
|
||||||
|
model.train()
|
||||||
|
optimizer.zero_grad()
|
||||||
|
out = model(batch.x, batch.edge_index)
|
||||||
|
# lossee = loss(out, torch.ones(x.shape[0], size_O))
|
||||||
|
lossee = loss(out, torch.ones(1, size_O))
|
||||||
|
lossee.backward()
|
||||||
|
optimizer.step()
|
||||||
|
target = torch.LongTensor([[0]])
|
||||||
|
|
||||||
|
for kind in ["node"]:
|
||||||
|
for name in __all__captum + __all__graphxai:
|
||||||
|
if name in __all__captum:
|
||||||
|
explaining_algorithm = CaptumWrapper(name)
|
||||||
|
elif name in __all__graphxai:
|
||||||
|
explaining_algorithm = GraphXAIWrapper(
|
||||||
|
name, in_channels=in_channels, criterion="cross-entropy"
|
||||||
|
)
|
||||||
|
|
||||||
|
print(name)
|
||||||
|
try:
|
||||||
|
explainer = Explainer(
|
||||||
|
model=model,
|
||||||
|
algorithm=explaining_algorithm,
|
||||||
|
explainer_config=dict(
|
||||||
|
explanation_type="phenomenon",
|
||||||
|
node_mask_type="object",
|
||||||
|
edge_mask_type="object",
|
||||||
|
),
|
||||||
|
model_config=dict(
|
||||||
|
mode="classification",
|
||||||
|
task_level=kind,
|
||||||
|
return_type="raw",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
explanation = explainer(
|
||||||
|
x=batch.x,
|
||||||
|
edge_index=batch.edge_index,
|
||||||
|
index=int(target),
|
||||||
|
target=batch.y,
|
||||||
|
)
|
||||||
|
print(explanation.__dict__)
|
||||||
|
except Exception as e:
|
||||||
|
print(str(e))
|
||||||
|
pass
|
|
@ -0,0 +1,19 @@
|
||||||
|
from abc import ABC
|
||||||
|
|
||||||
|
|
||||||
|
class Metric(ABC):
|
||||||
|
def __init__(self, name: str, model: torch.nn.Module = None, **kwargs):
|
||||||
|
self.name = name
|
||||||
|
self.model = model
|
||||||
|
if is_model_needed and model is None:
|
||||||
|
raise ValueError(f"{self.name} needs model to perform measurements")
|
||||||
|
|
||||||
|
def is_model_needed(self):
|
||||||
|
if "fidelity" in self.name:
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def __call__(self, exp: Explanation, **kwargs) -> float:
|
||||||
|
pass
|
|
@ -140,6 +140,15 @@ class GraphStat(object):
|
||||||
name: lambda x, name=name: x.__getattr__(name) if hasattr(x, name) else None
|
name: lambda x, name=name: x.__getattr__(name) if hasattr(x, name) else None
|
||||||
for name in names
|
for name in names
|
||||||
}
|
}
|
||||||
|
maps_add_assortativity = {
|
||||||
|
"assortativity": lambda x: torch_geometric.utils.assortativity(x.edge_index)
|
||||||
|
}
|
||||||
|
maps_add_homophily = {
|
||||||
|
f"homophily_{approach}": lambda x: torch_geometric.utils.homophily(
|
||||||
|
edge_index=x.edge_index, y=x.y, method=approach
|
||||||
|
)
|
||||||
|
for approach in ["edge", "node", "edge_insensitive"]
|
||||||
|
}
|
||||||
return maps
|
return maps
|
||||||
|
|
||||||
def __call__(self, data):
|
def __call__(self, data):
|
||||||
|
|
Loading…
Reference in New Issue