Fixing bugs and adding new features
parent db0be0ddb7
commit 9ad5adb33e

@@ -114,7 +114,7 @@ def set_cfg(explaining_cfg):
     explaining_cfg.threshold_config.threshold_type = None
 
-    explaining_cfg.threshold_config.value = [0.3, 0.5, 0.7]
+    explaining_cfg.threshold_config.value = [i * 0.05 for i in range(21)]
 
     explaining_cfg.threshold_config.relu_and_normalize = True

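Note: the new grid replaces the three hand-picked thresholds with a regular sweep. A quick check of what the comprehension produces (plain Python, no framework needed):

    values = [i * 0.05 for i in range(21)]
    assert len(values) == 21               # 0.00, 0.05, ..., 1.00
    assert values[0] == 0.0
    assert abs(values[-1] - 1.0) < 1e-9    # floating-point grid, not exactly 1.0
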
@@ -128,8 +128,7 @@ def set_cfg(explaining_cfg):
     explaining_cfg.metrics.force = False
 
     explaining_cfg.attack = CN()
-    explaining_cfg.attack.name = 'all'
+    explaining_cfg.attack.name = "all"
 
     explaining_cfg.accelerator = "auto"

@@ -31,16 +31,16 @@ __all__captum = [
 __all__graphxai = [
     "CAM",
-    # "GradCAM",
-    # "GNN_LRP",
-    # "GradExplainer",
-    # "GuidedBackPropagation",
-    # "IntegratedGradients",
-    # "PGExplainer",
-    # "PGMExplainer",
-    # "RandomExplainer",
-    # "SubgraphX",
-    # "GraphMASK",
+    "GradCAM",
+    "GNN_LRP",
+    "GradExplainer",
+    "GuidedBackPropagation",
+    "IntegratedGradients",
+    "PGExplainer",
+    "PGMExplainer",
+    "RandomExplainer",
+    "SubgraphX",
+    "GraphMASK",
 ]
 

@@ -82,12 +82,12 @@ for epoch in range(1, 2):
 target = torch.LongTensor([[0]])
 
 for kind in ["graph"]:
-    for name in __all__graphxai:
+    for name in __all__graphxai + __all__captum:
         if name in __all__captum:
             explaining_algorithm = CaptumWrapper(name)
         elif name in __all__graphxai:
             explaining_algorithm = GraphXAIWrapper(
-                name, in_channels=in_channels, criterion="cross-entropy"
+                name, in_channels=in_channels, criterion="cross_entropy"
             )
 
         print(name)

@@ -105,7 +105,7 @@ for kind in ["graph"]:
                     task_level=kind,
                     return_type="raw",
                 ),
-                threshold_config=dict(threshold_type="hard", value=0.5),
+                # threshold_config=dict(threshold_type=None, value=0.5),
             )
             explanation = explainer(
                 x=batch.x,

@@ -117,29 +117,29 @@ for kind in ["graph"]:
             # explanation.__setattr__(
             #     "model_prediction", explainer.get_prediction(x, edge_index)
             # )
-            explanation_threshold = explanation._apply_masks(
-                node_mask=torch.ones_like(explanation.node_mask).bool()
-            )
+            # explanation_threshold = explanation._apply_masks(
+            #     node_mask=torch.ones_like(explanation.node_mask).bool()
+            # )
 
-            print(explanation_threshold.__dict__)
+            # print(explanation_threshold.__dict__)
 
-            for f_name in [
-                "gaussian_noise",
-                "add_edge",
-                "remove_edge",
-                "remove_node",
-                "pgd",
-                "fgsm",
-            ]:
-                print(f_name)
-                acc = Attack(name=f_name, model=model, loss=loss)
-                # gt = torch.ones_like(explanation_threshold.node_mask) / 2
-                # mask = explanation_threshold.node_mask.bool()
-                # target = (1 - gt).bool()
-                # target[1] = False
-                # print(mask, target)
-                out = acc.forward(explanation)
-                print(out)
+            # for f_name in [
+            #     "gaussian_noise",
+            #     "add_edge",
+            #     "remove_edge",
+            #     "remove_node",
+            #     "pgd",
+            #     "fgsm",
+            # ]:
+            #     print(f_name)
+            #     acc = Attack(name=f_name, model=model, loss=loss)
+            #     # gt = torch.ones_like(explanation_threshold.node_mask) / 2
+            #     # mask = explanation_threshold.node_mask.bool()
+            #     # target = (1 - gt).bool()
+            #     # target[1] = False
+            #     # print(mask, target)
+            #     out = acc.forward(explanation)
+            #     print(out)
 
         except Exception as e:
             traceback.print_exc()

@@ -39,14 +39,9 @@ class Metric(ABC):
             **kwargs (optional): Additional keyword arguments passed to the
                 model.
         """
-        training = self.model.training
-        self.model.eval()
+        print(args, kwargs)
 
         with torch.no_grad():
-            out = self.model(*args, **kwargs)
-
-        self.model.train(training)
+            out = self.model(*args, **kwargs)[0]
 
         return out

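Note: the new body drops the eval()/train() bookkeeping and indexes the model output with [0], which suggests the wrapped GraphGym model returns a (prediction, label) tuple; the print call reads like leftover debugging. A minimal sketch of an eval-safe variant, assuming that tuple convention (a suggested cleanup, not what the commit ships):

    def get_prediction(self, *args, **kwargs):
        training = self.model.training        # remember the current mode
        self.model.eval()
        with torch.no_grad():
            out = self.model(*args, **kwargs)[0]  # keep only the prediction
        self.model.train(training)            # restore the previous mode
        return out
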
@@ -1,9 +1,9 @@
 import torch
 import torch.nn.functional as F
+from torch import Tensor
 from torch.nn import KLDivLoss, Softmax
 from torch_geometric.explain.explanation import Explanation
 from torch_geometric.graphgym.config import cfg
-from torch import Tensor
 
 from explaining_framework.metric.base import Metric

@@ -118,9 +118,19 @@ class Fidelity(Metric):
         )
         pos_fidelity = self._fidelity_plus_prob(exp)
         neg_fidelity = self._fidelity_minus_prob(exp)
-        denom = (pos_weight / pos_fidelity) + (neg_weight / (1.0 - neg_fidelity))
-        return 1.0 / denom
+        if (
+            pos_fidelity == 0
+            or pos_fidelity == 1
+            or neg_fidelity == 0
+            or neg_fidelity == 1
+        ):
+            return None
+        else:
+            denom = (pos_weight / pos_fidelity) + (neg_weight / (1.0 - neg_fidelity))
+            if denom == 0:
+                return None
+            else:
+                return 1.0 / denom
 
     def _characterization(
         self,

@@ -136,8 +146,19 @@ class Fidelity(Metric):
         pos_fidelity = self._fidelity_plus(exp)
         neg_fidelity = self._fidelity_minus(exp)
 
-        denom = (pos_weight / pos_fidelity) + (neg_weight / (1.0 - neg_fidelity))
-        return 1.0 / denom
+        if (
+            pos_fidelity == 0
+            or pos_fidelity == 1
+            or neg_fidelity == 0
+            or neg_fidelity == 1
+        ):
+            return None
+        else:
+            denom = (pos_weight / pos_fidelity) + (neg_weight / (1.0 - neg_fidelity))
+            if denom == 0:
+                return None
+            else:
+                return 1.0 / denom
 
     def score(self, exp):
         self.exp_sub = exp.get_explanation_subgraph()

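Note: both hunks guard the same quantity. The characterization score is the weighted harmonic mean of fid+ and (1 - fid-), so it is undefined whenever a denominator hits zero. A standalone sketch of the guarded formula (the 0.5/0.5 default weights are an assumption here, not taken from the diff):

    def characterization(pos_fidelity, neg_fidelity, pos_weight=0.5, neg_weight=0.5):
        # fid+ == 0 divides by zero directly, and fid- == 1 zeroes (1 - fid-);
        # the commit also rejects the opposite endpoints as degenerate.
        if pos_fidelity in (0, 1) or neg_fidelity in (0, 1):
            return None
        denom = (pos_weight / pos_fidelity) + (neg_weight / (1.0 - neg_fidelity))
        return 1.0 / denom if denom != 0 else None
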
@@ -26,7 +26,7 @@ class FGSM(Metric):
         lower_bound: float = float("-inf"),
         upper_bound: float = float("inf"),
     ):
-        super().__init__(name=name, model=model)
+        super().__init__(name="fgsm", model=model)
         self.model = model
         self.loss = loss
         self.lower_bound = lower_bound

@@ -51,6 +51,9 @@ class FGSM(Metric):
         )
         return input_
 
+    def load_metric(self):
+        pass
+
 
 class PGD(Metric):
     def __init__(

@@ -60,7 +63,7 @@ class PGD(Metric):
         lower_bound: float = float("-inf"),
         upper_bound: float = float("inf"),
     ):
-        super().__init__(name=name, model=model)
+        super().__init__(name="pgd", model=model)
         self.model = model
         self.loss = loss
         self.lower_bound = lower_bound

@@ -105,6 +108,9 @@ class PGD(Metric):
         perturbed_input.x = self.bound(perturbed_input.x).detach()
         return perturbed_input
 
+    def load_metric(self):
+        pass
+
     def _random_point(
         self, center: torch.Tensor, radius: float, norm: str
     ) -> torch.Tensor:

@@ -135,6 +141,7 @@ class Attack(Metric):
         dropout: float = 0.5,
         loss: torch.nn = None,
     ):
 
         super().__init__(name=name, model=model)
         self.name = name
         self.model = model

@@ -148,12 +155,12 @@ class Attack(Metric):
         ]
         self.dropout = dropout
         if loss is None:
-            if cfg.model.loss_fun == "cross-entropy":
+            if cfg.model.loss_fun == "cross_entropy":
                 self.loss = CrossEntropyLoss()
-            if cfg.model.loss_fun == "mse":
+            elif cfg.model.loss_fun == "mse":
                 self.loss = MSELoss()
             else:
-                raise ValueError
+                raise ValueError(f"{loss} is not supported yet")
         else:
             self.loss = loss
         self.load_metric(name)

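Note: one caveat with the new error message. On the `else` branch `loss` is still None, so the f-string always renders "None is not supported yet". A sketch that reports the value that actually failed to match (a suggested variant, not what the commit ships):

    if loss is None:
        if cfg.model.loss_fun == "cross_entropy":
            self.loss = CrossEntropyLoss()
        elif cfg.model.loss_fun == "mse":
            self.loss = MSELoss()
        else:
            # name the configured loss, which is what failed to match
            raise ValueError(f"{cfg.model.loss_fun} is not supported yet")
    else:
        self.loss = loss
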
@@ -19,7 +19,7 @@ class Sparsity(Metric):
 
     def forward(self, exp: Explanation) -> float:
         out = {}
-        for k, v in exp.to_dict():
+        for k, v in exp.to_dict().items():
             if "mask" in k and v.dtype == torch.bool:
                 out[k] = torch.mean(mask.float()).item()
         return out

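Note: the fixed loop now unpacks (key, value) pairs, but the body still reads an undefined name `mask`; from context it presumably means the unpacked tensor `v`. A self-contained sketch of the intended computation, the fraction of entries each boolean mask keeps:

    import torch

    def sparsity_scores(exp_dict: dict) -> dict:
        out = {}
        for k, v in exp_dict.items():
            if "mask" in k and v.dtype == torch.bool:
                out[k] = torch.mean(v.float()).item()  # fraction of kept entries
        return out

    print(sparsity_scores({"node_mask": torch.tensor([True, False, True, False])}))
    # {'node_mask': 0.5}
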
@@ -1,6 +1,7 @@
 #!/usr/bin/env python
 # -*- coding: utf-8 -*-
 
+import copy
 import types
 from inspect import getmembers, isfunction, signature

@@ -152,7 +153,7 @@ class GraphStat(object):
         return maps
 
     def __call__(self, data):
-        data_ = data.__copy__()
+        data_ = copy.copy(data)
         datahash = hash(data.__repr__)
         stats = {}
         for k, v in self.maps.items():

@@ -160,7 +161,7 @@ class GraphStat(object):
                 _data_ = to_networkx(data)
                 _data_ = _data_.to_undirected()
             elif k == "torch_geometric":
-                _data_ = data.__copy__()
+                _data_ = copy.copy(data)
             for name, func in v.items():
                 try:
                     val = func(_data_)

@@ -122,7 +122,7 @@ class LoadModelInfo(object):
         if self.info is None:
             self.set_info()
 
-        model_name = os.path.basename(self.info["xp_dir_name"])
+        model_name = os.path.basename(self.info["xp_dir_path"])
         model_seed = self.info["seed"]
         epoch = os.path.basename(self.info["ckpt_path"])
         model_signature = "-".join(

@@ -4,12 +4,12 @@ from typing import Any
 from eixgnn.eixgnn import EiXGNN
 from scgnn.scgnn import SCGNN
 from torch_geometric.data import Batch, Data
-from torch_geometric.loader.dataloader import DataLoader
 from torch_geometric.explain import Explainer
 from torch_geometric.graphgym.config import cfg
 from torch_geometric.graphgym.loader import create_dataset
 from torch_geometric.graphgym.model_builder import cfg, create_model
 from torch_geometric.graphgym.utils.device import auto_select_device
+from torch_geometric.loader.dataloader import DataLoader
 
 from explaining_framework.config.explainer_config.eixgnn_config import \
     eixgnn_cfg

@@ -53,6 +53,7 @@ all__graphxai = [
     "RandomExplainer",
     "SubgraphX",
     "GraphMASK",
+    "GNNExplainer",
 ]
 
 all__own = ["EIXGNN", "SCGNN"]

@@ -82,6 +83,7 @@ all_robust = [
     "pgd",
     "fgsm",
 ]
+all_sparsity = ["l0"]
 
 
 class ExplainingOutline(object):

@@ -108,6 +110,7 @@ class ExplainingOutline(object):
         self.load_explainer()
         self.load_metric()
         self.load_attack()
+        self.load_dataset_to_dataloader()
 
     def load_model_info(self):
         info = LoadModelInfo(

@@ -171,7 +174,12 @@ class ExplainingOutline(object):
         if isinstance(self.explaining_cfg.dataset.specific_items, int):
             ind = self.explaining_cfg.dataset.specific_items
             self.dataset = self.dataset[ind : ind + 1]
-        self.dataset = DataLoader(dataset=dataset, shuffle=False, batch_size=1)
+        elif isinstance(self.explaining_cfg.dataset.specific_items, list):
+            ind = self.explaining_cfg.dataset.specific_items
+            self.dataset = self.dataset[ind]
+
+    def load_dataset_to_dataloader(self):
+        self.dataset = DataLoader(dataset=self.dataset, shuffle=False, batch_size=1)
 
     def load_explainer(self):
         self.load_explainer_cfg()

@@ -217,18 +225,20 @@ class ExplainingOutline(object):
         if self.explaining_cfg is None:
             self.load_explaining_cfg()
 
-        name_ = self.explaining_cfg.metrics.type
+        name_ = self.explaining_cfg.metrics.name
 
         if name_ == "all":
-            all_fid_metrics = [Fidelity(name) for name in all_fidelity]
+            all_fid_metrics = [
+                Fidelity(name=name, model=self.model) for name in all_fidelity
+            ]
             all_spa_metrics = [Sparsity(name) for name in all_sparsity]
-            self.metrics = all_acc_metrics + all_fid_metrics
+            self.metrics = all_spa_metrics + all_fid_metrics
 
             if self.explaining_cfg.dataset.name == "BASHAPES":
                 all_acc_metrics = [Accuracy(name) for name in all_accuracy]
                 self.metrics = self.metrics + all_acc_metrics
         elif name_ in all_fidelity:
-            self.metrics = [Fidelity(name_)]
+            self.metrics = [Fidelity(name=name_, model=self.model)]
         elif name_ in all_sparsity:
             self.metrics = [Sparsity(name_)]
         elif name_ in all_accuracy:

@@ -250,11 +260,13 @@ class ExplainingOutline(object):
             self.load_explaining_cfg()
         name_ = self.explaining_cfg.attack.name
         if name_ == "all":
-            all_rob_metrics = [Attack(name) for name in all_robust]
+            all_rob_metrics = [
+                Attack(name=name, model=self.model) for name in all_robust
+            ]
             self.attacks = all_rob_metrics
         elif name_ in all_robust:
-            self.attacks = [Attack(name_)]
+            self.attacks = [Attack(name=name_, model=self.model)]
         elif name_ is None:
-            slef.attacks = []
+            self.attacks = []
         else:
             raise ValueError(f"{name_} is an Attack method that is not supported yet")

@@ -1,7 +1,10 @@
 import copy
 
+import torch
 from torch import FloatTensor
 from torch.nn import ReLU
+from torch_geometric.explain.explanation import Explanation
 
 
 class Adjust(object):
     def __init__(

@@ -20,7 +23,7 @@ class Adjust(object):
             self.apply_relu = False
 
     def forward(self, exp: Explanation) -> Explanation:
-        exp_ = exp.copy()
+        exp_ = copy.copy(exp)
         _store = exp_.to_dict()
         for k, v in _store.items():
             if "mask" in k:

@@ -61,5 +64,7 @@ class Adjust(object):
         return mask
 
     def absolute(self, mask: FloatTensor) -> FloatTensor:
+        print("######################### MASK")
+        print(mask)
         mask_ = torch.abs(mask)
         return mask_

@@ -2,6 +2,7 @@ import copy
 import json
 import os
 
+import torch
 from torch_geometric.data import Data
 from torch_geometric.explain.explanation import Explanation

@@ -12,7 +13,7 @@ def explanation_verification(exp: Explanation) -> bool:
     for mask in masks:
         is_nan = mask.isnan().any().item()
         is_inf = mask.isinf().any().item()
-        is_const = mask.max()==mask.min()
+        is_const = mask.max() == mask.min()
         is_ok = exp.validate()
         if is_nan or is_inf or not is_ok or is_const:
             is_good = False

@@ -25,8 +26,10 @@ def explanation_verification(exp: Explanation) -> bool:
 def save_explanation(exp: Explanation, path: str) -> None:
     data = copy.copy(exp).to_dict()
     for k, v in data.items():
+        print(k, v)
         if isinstance(v, torch.Tensor):
             data[k] = v.detach().cpu().tolist()
 
     with open(path, "w") as f:
         json.dump(data, f)

@@ -48,8 +51,8 @@ def normalize_explanation_masks(exp: Explanation, p: str = "inf") -> Explanation
     data = exp.to_dict()
     for k, v in data.items():
         if "_mask" in k and isinstance(v, torch.FloatTensor):
-            norm =torch.norm(input=data[k], p=p, dim=None).item()
-            if norm.item()>0:
+            norm = torch.norm(input=data[k], p=p, dim=None).item()
+            if norm.item() > 0:
                 data[k] = data[k] / norm
 
     return exp

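Note: the whitespace fix leaves a latent bug in place. `norm` is already a Python float after `.item()`, so the following `norm.item() > 0` would raise AttributeError on first use. A corrected sketch under that reading (using float("inf") rather than the string "inf", since torch.norm expects a numeric p):

    import torch

    def normalize_masks(data: dict, p: float = float("inf")) -> dict:
        for k, v in data.items():
            if "_mask" in k and isinstance(v, torch.FloatTensor):
                norm = torch.norm(input=v, p=p, dim=None).item()  # already a float
                if norm > 0:
                    data[k] = v / norm
        return data
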
@@ -31,11 +31,8 @@ def is_exists(path: str) -> bool:
 
 
 def get_obj_config(obj):
-    config = {k: getattr(obj, k) for k in dir(obj)}
     config = {
-        k: v
-        for k, v in config.items()
-        if isinstance(v, (int, float, str, bool)) or v is None
+        k: v for k, v in obj.__dict__.items() if isinstance(v, (int, float, str, bool))
     }
     return config

main.py

@@ -2,9 +2,11 @@
 # -*- coding: utf-8 -*-
 #
 
+import copy
 import os
 import time
 
+import torch
 from torch_geometric import seed_everything
 from torch_geometric.data.makedirs import makedirs
 from torch_geometric.explain import Explainer

@@ -16,27 +18,39 @@ from explaining_framework.config.explaining_config import explaining_cfg
 from explaining_framework.utils.explaining.cmd_args import parse_args
 from explaining_framework.utils.explaining.outline import ExplainingOutline
 from explaining_framework.utils.explanation.adjust import Adjust
-from explaining_framework.utils.io import (obj_config_to_str, read_json,
-                                           write_json, write_yaml)
+from explaining_framework.utils.explanation.io import (
+    explanation_verification, load_explanation, save_explanation)
+from explaining_framework.utils.io import (is_exists, obj_config_to_str,
+                                           read_json, write_json, write_yaml)
 
 # inference, time, force,
 
 
-def get_pred(explanation, force=False):
-    dict_ = explanation.to_dict()
-    if dict_.get("pred") is None or dict_.get("pred_masked") or force:
-        pred = explainer.get_prediction(explanation)
+def get_pred(explainer, explanation):
+    pred = explainer.get_prediction(x=explanation.x, edge_index=explanation.edge_index)[
+        0
+    ]
+    setattr(explanation, "pred", pred)
+    data = explanation.to_dict()
+    if not data.get("node_mask") is None or not data.get("edge_mask") is None:
         pred_masked = explainer.get_masked_prediction(
             x=explanation.x,
             edge_index=explanation.edge_index,
-            node_mask=explanation.node_mask,
-            edge_mask=explanation.edge_mask,
-        )
-        explanation.__setattr__("pred", pred)
-        explanation.__setattr__("pred_masked", pred_masked)
-        return explanation
-    else:
-        return explanation
+            node_mask=data.get("node_mask"),
+            edge_mask=data.get("edge_mask"),
+        )[0]
+        setattr(explanation, "pred_exp", pred_masked)
+
+
+def get_explanation(explainer, item):
+    explanation = explainer(
+        x=item.x,
+        edge_index=item.edge_index,
+        index=int(item.y),
+        target=item.y,
+    )
+    assert explanation_verification(explanation)
+    return explanation
 
 
 if __name__ == "__main__":

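Note: the refactor splits the per-item work into two helpers. get_explanation runs the explainer and verifies the result; get_pred attaches the model's raw prediction and, when a node or edge mask is present, the masked prediction. A minimal usage sketch (item and explainer as built in the __main__ block; the pred/pred_exp attribute names follow the commit):

    explanation = get_explanation(explainer, item)  # run + verify
    get_pred(explainer, explanation)                # mutates the explanation in place
    print(explanation.pred)      # prediction on the full input
    print(explanation.pred_exp)  # masked prediction, only set when a mask exists
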
@@ -45,8 +59,9 @@ if __name__ == "__main__":
     auto_select_device()
 
     # Load components
-    dataset = outline.dataset.to(cfg.accelerator)
+    dataset = outline.dataset
     model = outline.model.to(cfg.accelerator)
+    model = model.eval()
     model_info = outline.model_info
     metrics = outline.metrics
     explaining_algorithm = outline.explaining_algorithm

@@ -87,53 +102,57 @@ if __name__ == "__main__":
             return_type=explaining_cfg.model_config.return_type,
         ),
     )
+    if not explaining_cfg.dataset.specific_items is None:
+        indexes = explaining_cfg.dataset.specific_items
+    else:
+        indexes = range(len(dataset))
     # Save explaining configuration
-    for index, item in enumerate(dataset):
+    for index, item in zip(indexes, dataset):
+        item = item.to(cfg.accelerator)
         save_raw_path = os.path.join(global_path, "raw")
         makedirs(save_raw_path)
         explanation_path = os.path.join(save_raw_path, f"{index}.json")
 
         if is_exists(explanation_path):
             if explaining_cfg.explainer.force:
-                explanation = explainer(
-                    x=item.x,
-                    edge_index=item.edge_index,
-                    index=item.y,
-                    target=item.y,
-                )
+                explanation = get_explanation(explainer, item)
             else:
                 explanation = load_explanation(explanation_path)
         else:
-            explanation = explainer(
-                x=item.x,
-                edge_index=item.edge_index,
-                index=item.y,
-                target=item.y,
-            )
-            explanation = get_pred(explanation, force=False)
+            explanation = get_explanation(explainer, item)
+            explanation = explanation.to(cfg.accelerator)
+            get_pred(explainer=explainer, explanation=explanation)
         save_explanation(explanation, explanation_path)
         for apply_relu in [True, False]:
             for apply_absolute in [True, False]:
                 adjust = Adjust(apply_relu=apply_relu, apply_absolute=apply_absolute)
-                save_raw_path = os.path.join(
+                save_raw_path_ = os.path.join(
                     global_path, f"adjust-{obj_config_to_str(adjust)}"
                 )
-                makedirs(save_raw_path)
-                explanation = adjust.forward(explanation)
-                explanation_path = os.path.join(save_raw_path, f"{index}.json")
-                explanation = get_pred(explanation, force=True)
-                save_explanation(explanation, explanation_path)
+                explanation__ = copy.copy(explanation).to(cfg.accelerator)
+                makedirs(save_raw_path_)
+                explanation = adjust.forward(explanation__)
+                explanation_path = os.path.join(save_raw_path_, f"{index}.json")
+                get_pred(explainer, explanation__)
+                save_explanation(explanation__, explanation_path)
 
                 for threshold_approach in ["hard", "topk", "topk_hard"]:
-                    for threshold_value in explaining_cfg.threshold_config.value:
+                    if threshold_approach == "hard":
+                        threshold_values = explaining_cfg.threshold_config.value
+                    elif "topk" in threshold_approach:
+                        threshold_values = [3, 5, 10, 20]
+                    for threshold_value in threshold_values:
+
                         masking_path = os.path.join(
-                            save_raw_path,
-                            f"threshold={threshold_approach}-value={value}",
+                            save_raw_path_,
+                            f"threshold={threshold_approach}-value={threshold_value}",
                         )
+                        makedirs(masking_path)
                         exp_threshold_path = os.path.join(masking_path, f"{index}.json")
                         if is_exists(exp_threshold_path):
-                            explanation = load_explanation(exp_threshold_path)
+                            exp_threshold = load_explanation(exp_threshold_path)
                         else:
                             threshold_conf = {
                                 "threshold_type": threshold_approach,

@@ -143,17 +162,21 @@ if __name__ == "__main__":
                                 threshold_conf
                             )
 
-                            expl = copy.copy(explanation)
+                            expl = copy.copy(explanation__).to(cfg.accelerator)
                             exp_threshold = explainer._post_process(expl)
-                            exp_threshold = get_pred(exp_threshold, force=True)
+                            exp_threshold = exp_threshold.to(cfg.accelerator)
+                            get_pred(explainer, exp_threshold)
                             save_explanation(exp_threshold, exp_threshold_path)
                         for metric in metrics:
                             metric_path = os.path.join(
                                 masking_path, f"{obj_config_to_str(metric)}"
                             )
+                            makedirs(metric_path)
                             if is_exists(os.path.join(metric_path, f"{index}.json")):
                                 continue
                             else:
                                 out = metric.forward(exp_threshold)
-                                write_json({f"{metric.name}": out})
+                                write_json(
+                                    {f"{metric.name}": out},
+                                    os.path.join(metric_path, f"{index}.json"),
+                                )

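Note: the value grids in the sweep above differ per approach because PyTorch Geometric's threshold types mean different things: "hard" compares mask entries against a float cutoff, while "topk" keeps the k largest entries and "topk_hard" additionally binarizes them, hence the integer list. A toy illustration of the three behaviors (values chosen for illustration only; check your PyG version's Explainer for the exact post-processing):

    import torch

    mask = torch.tensor([0.9, 0.2, 0.6, 0.1])
    hard = (mask > 0.5).float()                                  # [1., 0., 1., 0.]
    value, index = mask.topk(k=2)
    topk = torch.zeros_like(mask).scatter_(0, index, value)      # [0.9, 0., 0.6, 0.]
    topk_hard = torch.zeros_like(mask).scatter_(0, index, 1.0)   # [1., 0., 1., 0.]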