Fixing bugs and adding new features

This commit is contained in:
araison 2023-01-02 23:37:40 +01:00
parent db0be0ddb7
commit 9ad5adb33e
13 changed files with 183 additions and 120 deletions

View File

@ -114,7 +114,7 @@ def set_cfg(explaining_cfg):
explaining_cfg.threshold_config.threshold_type = None
explaining_cfg.threshold_config.value = [0.3, 0.5, 0.7]
explaining_cfg.threshold_config.value = [i * 0.05 for i in range(21)]
explaining_cfg.threshold_config.relu_and_normalize = True
@ -128,8 +128,7 @@ def set_cfg(explaining_cfg):
explaining_cfg.metrics.force = False
explaining_cfg.attack = CN()
explaining_cfg.attack.name = 'all'
explaining_cfg.attack.name = "all"
explaining_cfg.accelerator = "auto"

View File

@ -31,16 +31,16 @@ __all__captum = [
__all__graphxai = [
"CAM",
# "GradCAM",
# "GNN_LRP",
# "GradExplainer",
# "GuidedBackPropagation",
# "IntegratedGradients",
# "PGExplainer",
# "PGMExplainer",
# "RandomExplainer",
# "SubgraphX",
# "GraphMASK",
"GradCAM",
"GNN_LRP",
"GradExplainer",
"GuidedBackPropagation",
"IntegratedGradients",
"PGExplainer",
"PGMExplainer",
"RandomExplainer",
"SubgraphX",
"GraphMASK",
]
@ -82,12 +82,12 @@ for epoch in range(1, 2):
target = torch.LongTensor([[0]])
for kind in ["graph"]:
for name in __all__graphxai:
for name in __all__graphxai + __all__captum:
if name in __all__captum:
explaining_algorithm = CaptumWrapper(name)
elif name in __all__graphxai:
explaining_algorithm = GraphXAIWrapper(
name, in_channels=in_channels, criterion="cross-entropy"
name, in_channels=in_channels, criterion="cross_entropy"
)
print(name)
@ -105,7 +105,7 @@ for kind in ["graph"]:
task_level=kind,
return_type="raw",
),
threshold_config=dict(threshold_type="hard", value=0.5),
# threshold_config=dict(threshold_type=None, value=0.5),
)
explanation = explainer(
x=batch.x,
@ -117,29 +117,29 @@ for kind in ["graph"]:
# explanation.__setattr__(
# "model_prediction", explainer.get_prediction(x, edge_index)
# )
explanation_threshold = explanation._apply_masks(
node_mask=torch.ones_like(explanation.node_mask).bool()
)
# explanation_threshold = explanation._apply_masks(
# node_mask=torch.ones_like(explanation.node_mask).bool()
# )
print(explanation_threshold.__dict__)
# print(explanation_threshold.__dict__)
for f_name in [
"gaussian_noise",
"add_edge",
"remove_edge",
"remove_node",
"pgd",
"fgsm",
]:
print(f_name)
acc = Attack(name=f_name, model=model, loss=loss)
# gt = torch.ones_like(explanation_threshold.node_mask) / 2
# mask = explanation_threshold.node_mask.bool()
# target = (1 - gt).bool()
# target[1] = False
# print(mask, target)
out = acc.forward(explanation)
print(out)
# for f_name in [
# "gaussian_noise",
# "add_edge",
# "remove_edge",
# "remove_node",
# "pgd",
# "fgsm",
# ]:
# print(f_name)
# acc = Attack(name=f_name, model=model, loss=loss)
# # gt = torch.ones_like(explanation_threshold.node_mask) / 2
# # mask = explanation_threshold.node_mask.bool()
# # target = (1 - gt).bool()
# # target[1] = False
# # print(mask, target)
# out = acc.forward(explanation)
# print(out)
except Exception as e:
traceback.print_exc()

View File

@ -39,14 +39,9 @@ class Metric(ABC):
**kwargs (optional): Additional keyword arguments passed to the
model.
"""
training = self.model.training
self.model.eval()
print(args, kwargs)
with torch.no_grad():
out = self.model(*args, **kwargs)
self.model.train(training)
out = self.model(*args, **kwargs)[0]
return out

View File

@ -1,9 +1,9 @@
import torch
import torch.nn.functional as F
from torch import Tensor
from torch.nn import KLDivLoss, Softmax
from torch_geometric.explain.explanation import Explanation
from torch_geometric.graphgym.config import cfg
from torch import Tensor
from explaining_framework.metric.base import Metric
@ -118,9 +118,19 @@ class Fidelity(Metric):
)
pos_fidelity = self._fidelity_plus_prob(exp)
neg_fidelity = self._fidelity_minus_prob(exp)
denom = (pos_weight / pos_fidelity) + (neg_weight / (1.0 - neg_fidelity))
return 1.0 / denom
if (
pos_fidelity == 0
or pos_fidelity == 1
or neg_fidelity == 0
or neg_fidelity == 1
):
return None
else:
denom = (pos_weight / pos_fidelity) + (neg_weight / (1.0 - neg_fidelity))
if denom == 0:
return None
else:
return 1.0 / denom
def _characterization(
self,
@ -136,8 +146,19 @@ class Fidelity(Metric):
pos_fidelity = self._fidelity_plus(exp)
neg_fidelity = self._fidelity_minus(exp)
denom = (pos_weight / pos_fidelity) + (neg_weight / (1.0 - neg_fidelity))
return 1.0 / denom
if (
pos_fidelity == 0
or pos_fidelity == 1
or neg_fidelity == 0
or neg_fidelity == 1
):
return None
else:
denom = (pos_weight / pos_fidelity) + (neg_weight / (1.0 - neg_fidelity))
if denom == 0:
return None
else:
return 1.0 / denom
def score(self, exp):
self.exp_sub = exp.get_explanation_subgraph()

View File

@ -26,7 +26,7 @@ class FGSM(Metric):
lower_bound: float = float("-inf"),
upper_bound: float = float("inf"),
):
super().__init__(name=name, model=model)
super().__init__(name="fgsm", model=model)
self.model = model
self.loss = loss
self.lower_bound = lower_bound
@ -51,6 +51,9 @@ class FGSM(Metric):
)
return input_
def load_metric(self):
pass
class PGD(Metric):
def __init__(
@ -60,7 +63,7 @@ class PGD(Metric):
lower_bound: float = float("-inf"),
upper_bound: float = float("inf"),
):
super().__init__(name=name, model=model)
super().__init__(name="pgd", model=model)
self.model = model
self.loss = loss
self.lower_bound = lower_bound
@ -105,6 +108,9 @@ class PGD(Metric):
perturbed_input.x = self.bound(perturbed_input.x).detach()
return perturbed_input
def load_metric(self):
pass
def _random_point(
self, center: torch.Tensor, radius: float, norm: str
) -> torch.Tensor:
@ -135,6 +141,7 @@ class Attack(Metric):
dropout: float = 0.5,
loss: torch.nn = None,
):
super().__init__(name=name, model=model)
self.name = name
self.model = model
@ -148,12 +155,12 @@ class Attack(Metric):
]
self.dropout = dropout
if loss is None:
if cfg.model.loss_fun == "cross-entropy":
if cfg.model.loss_fun == "cross_entropy":
self.loss = CrossEntropyLoss()
if cfg.model.loss_fun == "mse":
elif cfg.model.loss_fun == "mse":
self.loss = MSELoss()
else:
raise ValueError
raise ValueError(f"{loss} is not supported yet")
else:
self.loss = loss
self.load_metric(name)

View File

@ -19,7 +19,7 @@ class Sparsity(Metric):
def forward(self, exp: Explanation) -> float:
out = {}
for k, v in exp.to_dict():
for k, v in exp.to_dict().items():
if "mask" in k and v.dtype == torch.bool:
out[k] = torch.mean(mask.float()).item()
return out

View File

@ -1,6 +1,7 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import copy
import types
from inspect import getmembers, isfunction, signature
@ -152,7 +153,7 @@ class GraphStat(object):
return maps
def __call__(self, data):
data_ = data.__copy__()
data_ = copy.copy(data)
datahash = hash(data.__repr__)
stats = {}
for k, v in self.maps.items():
@ -160,7 +161,7 @@ class GraphStat(object):
_data_ = to_networkx(data)
_data_ = _data_.to_undirected()
elif k == "torch_geometric":
_data_ = data.__copy__()
_data_ = copy.copy(data)
for name, func in v.items():
try:
val = func(_data_)

View File

@ -122,7 +122,7 @@ class LoadModelInfo(object):
if self.info is None:
self.set_info()
model_name = os.path.basename(self.info["xp_dir_name"])
model_name = os.path.basename(self.info["xp_dir_path"])
model_seed = self.info["seed"]
epoch = os.path.basename(self.info["ckpt_path"])
model_signature = "-".join(

View File

@ -4,12 +4,12 @@ from typing import Any
from eixgnn.eixgnn import EiXGNN
from scgnn.scgnn import SCGNN
from torch_geometric.data import Batch, Data
from torch_geometric.loader.dataloader import DataLoader
from torch_geometric.explain import Explainer
from torch_geometric.graphgym.config import cfg
from torch_geometric.graphgym.loader import create_dataset
from torch_geometric.graphgym.model_builder import cfg, create_model
from torch_geometric.graphgym.utils.device import auto_select_device
from torch_geometric.loader.dataloader import DataLoader
from explaining_framework.config.explainer_config.eixgnn_config import \
eixgnn_cfg
@ -53,6 +53,7 @@ all__graphxai = [
"RandomExplainer",
"SubgraphX",
"GraphMASK",
"GNNExplainer",
]
all__own = ["EIXGNN", "SCGNN"]
@ -82,6 +83,7 @@ all_robust = [
"pgd",
"fgsm",
]
all_sparsity = ["l0"]
class ExplainingOutline(object):
@ -108,6 +110,7 @@ class ExplainingOutline(object):
self.load_explainer()
self.load_metric()
self.load_attack()
self.load_dataset_to_dataloader()
def load_model_info(self):
info = LoadModelInfo(
@ -171,7 +174,12 @@ class ExplainingOutline(object):
if isinstance(self.explaining_cfg.dataset.specific_items, int):
ind = self.explaining_cfg.dataset.specific_items
self.dataset = self.dataset[ind : ind + 1]
self.dataset = DataLoader(dataset=dataset, shuffle=False, batch_size=1)
elif isinstance(self.explaining_cfg.dataset.specific_items, list):
ind = self.explaining_cfg.dataset.specific_items
self.dataset = self.dataset[ind]
def load_dataset_to_dataloader(self):
self.dataset = DataLoader(dataset=self.dataset, shuffle=False, batch_size=1)
def load_explainer(self):
self.load_explainer_cfg()
@ -217,18 +225,20 @@ class ExplainingOutline(object):
if self.explaining_cfg is None:
self.load_explaining_cfg()
name_ = self.explaining_cfg.metrics.type
name_ = self.explaining_cfg.metrics.name
if name_ == "all":
all_fid_metrics = [Fidelity(name) for name in all_fidelity]
all_fid_metrics = [
Fidelity(name=name, model=self.model) for name in all_fidelity
]
all_spa_metrics = [Sparsity(name) for name in all_sparsity]
self.metrics = all_acc_metrics + all_fid_metrics
self.metrics = all_spa_metrics + all_fid_metrics
if self.explaining_cfg.dataset.name == "BASHAPES":
all_acc_metrics = [Accuracy(name) for name in all_accuracy]
self.metrics = self.metrics + all_acc_metrics
elif name_ in all_fidelity:
self.metrics = [Fidelity(name_)]
self.metrics = [Fidelity(name=name_, model=self.model)]
elif name_ in all_sparsity:
self.metrics = [Sparsity(name_)]
elif name_ in all_accuracy:
@ -250,11 +260,13 @@ class ExplainingOutline(object):
self.load_explaining_cfg()
name_ = self.explaining_cfg.attack.name
if name_ == "all":
all_rob_metrics = [Attack(name) for name in all_robust]
all_rob_metrics = [
Attack(name=name, model=self.model) for name in all_robust
]
self.attacks = all_rob_metrics
elif name_ in all_robust:
self.attacks = [Attack(name_)]
self.attacks = [Attack(name=name_, model=self.model)]
elif name_ is None:
slef.attacks = []
self.attacks = []
else:
raise ValueError(f"{name_} is an Attack method that is not supported yet")

View File

@ -1,7 +1,10 @@
import copy
import torch
from torch import FloatTensor
from torch.nn import ReLU
from torch_geometric.explain.explanation import Explanation
class Adjust(object):
def __init__(
@ -20,7 +23,7 @@ class Adjust(object):
self.apply_relu = False
def forward(self, exp: Explanation) -> Explanation:
exp_ = exp.copy()
exp_ = copy.copy(exp)
_store = exp_.to_dict()
for k, v in _store.items():
if "mask" in k:
@ -61,5 +64,7 @@ class Adjust(object):
return mask
def absolute(self, mask: FloatTensor) -> FloatTensor:
print("######################### MASK")
print(mask)
mask_ = torch.abs(mask)
return mask_

View File

@ -2,6 +2,7 @@ import copy
import json
import os
import torch
from torch_geometric.data import Data
from torch_geometric.explain.explanation import Explanation
@ -12,7 +13,7 @@ def explanation_verification(exp: Explanation) -> bool:
for mask in masks:
is_nan = mask.isnan().any().item()
is_inf = mask.isinf().any().item()
is_const = mask.max()==mask.min()
is_const = mask.max() == mask.min()
is_ok = exp.validate()
if is_nan or is_inf or not is_ok or is_const:
is_good = False
@ -25,8 +26,10 @@ def explanation_verification(exp: Explanation) -> bool:
def save_explanation(exp: Explanation, path: str) -> None:
data = copy.copy(exp).to_dict()
for k, v in data.items():
print(k, v)
if isinstance(v, torch.Tensor):
data[k] = v.detach().cpu().tolist()
with open(path, "w") as f:
json.dump(data, f)
@ -48,8 +51,8 @@ def normalize_explanation_masks(exp: Explanation, p: str = "inf") -> Explanation
data = exp.to_dict()
for k, v in data.items():
if "_mask" in k and isinstance(v, torch.FloatTensor):
norm =torch.norm(input=data[k], p=p, dim=None).item()
if norm.item()>0:
norm = torch.norm(input=data[k], p=p, dim=None).item()
if norm.item() > 0:
data[k] = data[k] / norm
return exp

View File

@ -31,11 +31,8 @@ def is_exists(path: str) -> bool:
def get_obj_config(obj):
config = {k: getattr(obj, k) for k in dir(obj)}
config = {
k: v
for k, v in config.items()
if isinstance(v, (int, float, str, bool)) or v is None
k: v for k, v in obj.__dict__.items() if isinstance(v, (int, float, str, bool))
}
return config

111
main.py
View File

@ -2,9 +2,11 @@
# -*- coding: utf-8 -*-
#
import copy
import os
import time
import torch
from torch_geometric import seed_everything
from torch_geometric.data.makedirs import makedirs
from torch_geometric.explain import Explainer
@ -16,27 +18,39 @@ from explaining_framework.config.explaining_config import explaining_cfg
from explaining_framework.utils.explaining.cmd_args import parse_args
from explaining_framework.utils.explaining.outline import ExplainingOutline
from explaining_framework.utils.explanation.adjust import Adjust
from explaining_framework.utils.io import (obj_config_to_str, read_json,
write_json, write_yaml)
from explaining_framework.utils.explanation.io import (
explanation_verification, load_explanation, save_explanation)
from explaining_framework.utils.io import (is_exists, obj_config_to_str,
read_json, write_json, write_yaml)
# inference, time, force,
def get_pred(explanation, force=False):
dict_ = explanation.to_dict()
if dict_.get("pred") is None or dict_.get("pred_masked") or force:
pred = explainer.get_prediction(explanation)
def get_pred(explainer, explanation):
pred = explainer.get_prediction(x=explanation.x, edge_index=explanation.edge_index)[
0
]
setattr(explanation, "pred", pred)
data = explanation.to_dict()
if not data.get("node_mask") is None or not data.get("edge_mask") is None:
pred_masked = explainer.get_masked_prediction(
x=explanation.x,
edge_index=explanation.edge_index,
node_mask=explanation.node_mask,
edge_mask=explanation.edge_mask,
)
explanation.__setattr__("pred", pred)
explanation.__setattr__("pred_masked", pred_masked)
return explanation
else:
return explanation
node_mask=data.get("node_mask"),
edge_mask=data.get("edge_mask"),
)[0]
setattr(explanation, "pred_exp", pred_masked)
def get_explanation(explainer, item):
explanation = explainer(
x=item.x,
edge_index=item.edge_index,
index=int(item.y),
target=item.y,
)
assert explanation_verification(explanation)
return explanation
if __name__ == "__main__":
@ -45,8 +59,9 @@ if __name__ == "__main__":
auto_select_device()
# Load components
dataset = outline.dataset.to(cfg.accelerator)
dataset = outline.dataset
model = outline.model.to(cfg.accelerator)
model = model.eval()
model_info = outline.model_info
metrics = outline.metrics
explaining_algorithm = outline.explaining_algorithm
@ -87,53 +102,57 @@ if __name__ == "__main__":
return_type=explaining_cfg.model_config.return_type,
),
)
if not explaining_cfg.dataset.specific_items is None:
indexes = explaining_cfg.dataset.specific_items
else:
indexes = range(len(dataset))
# Save explaining configuration
for index, item in enumerate(dataset):
for index, item in zip(indexes, dataset):
item = item.to(cfg.accelerator)
save_raw_path = os.path.join(global_path, "raw")
makedirs(save_raw_path)
explanation_path = os.path.join(save_raw_path, f"{index}.json")
if is_exists(explanation_path):
if explaining_cfg.explainer.force:
explanation = explainer(
x=item.x,
edge_index=item.edge_index,
index=item.y,
target=item.y,
)
explanation = get_explanation(explainer, item)
else:
explanation = load_explanation(explanation_path)
else:
explanation = explainer(
x=item.x,
edge_index=item.edge_index,
index=item.y,
target=item.y,
)
explanation = get_pred(explanation, force=False)
explanation = get_explanation(explainer, item)
explanation = explanation.to(cfg.accelerator)
get_pred(explainer=explainer, explanation=explanation)
save_explanation(explanation, explanation_path)
for apply_relu in [True, False]:
for apply_absolute in [True, False]:
adjust = Adjust(apply_relu=apply_relu, apply_absolute=apply_absolute)
save_raw_path = os.path.join(
save_raw_path_ = os.path.join(
global_path, f"adjust-{obj_config_to_str(adjust)}"
)
makedirs(save_raw_path)
explanation = adjust.forward(explanation)
explanation_path = os.path.join(save_raw_path, f"{index}.json")
explanation = get_pred(explanation, force=True)
save_explanation(explanation, explanation_path)
explanation__ = copy.copy(explanation).to(cfg.accelerator)
makedirs(save_raw_path_)
explanation = adjust.forward(explanation__)
explanation_path = os.path.join(save_raw_path_, f"{index}.json")
get_pred(explainer, explanation__)
save_explanation(explanation__, explanation_path)
for threshold_approach in ["hard", "topk", "topk_hard"]:
for threshold_value in explaining_cfg.threshold_config.value:
if threshold_approach == "hard":
threshold_values = explaining_cfg.threshold_config.value
elif "topk" in threshold_approach:
threshold_values = [3, 5, 10, 20]
for threshold_value in threshold_values:
masking_path = os.path.join(
save_raw_path,
f"threshold={threshold_approach}-value={value}",
save_raw_path_,
f"threshold={threshold_approach}-value={threshold_value}",
)
makedirs(masking_path)
exp_threshold_path = os.path.join(masking_path, f"{index}.json")
if is_exists(exp_threshold_path):
explanation = load_explanation(exp_threshold_path)
exp_threshold = load_explanation(exp_threshold_path)
else:
threshold_conf = {
"threshold_type": threshold_approach,
@ -143,17 +162,21 @@ if __name__ == "__main__":
threshold_conf
)
expl = copy.copy(explanation)
expl = copy.copy(explanation__).to(cfg.accelerator)
exp_threshold = explainer._post_process(expl)
exp_threshold = get_pred(exp_threshold, force=True)
save_explanation(exp_threshold, exp_threshold_path)
exp_threshold = exp_threshold.to(cfg.accelerator)
get_pred(explainer, exp_threshold)
save_explanation(exp_threshold, exp_threshold_path)
for metric in metrics:
metric_path = os.path.join(
masking_path, f"{obj_config_to_str(metric)}"
)
makedirs(metric_path)
if is_exists(os.path.join(metric_path, f"{index}.json")):
continue
else:
out = metric.forward(exp_threshold)
write_json({f"{metric.name}": out})
write_json(
{f"{metric.name}": out},
os.path.join(metric_path, f"{index}.json"),
)