Fixing bugs and adding new features
This commit is contained in:
parent
db0be0ddb7
commit
9ad5adb33e
|
@ -114,7 +114,7 @@ def set_cfg(explaining_cfg):
|
|||
|
||||
explaining_cfg.threshold_config.threshold_type = None
|
||||
|
||||
explaining_cfg.threshold_config.value = [0.3, 0.5, 0.7]
|
||||
explaining_cfg.threshold_config.value = [i * 0.05 for i in range(21)]
|
||||
|
||||
explaining_cfg.threshold_config.relu_and_normalize = True
|
||||
|
||||
|
@ -128,8 +128,7 @@ def set_cfg(explaining_cfg):
|
|||
explaining_cfg.metrics.force = False
|
||||
|
||||
explaining_cfg.attack = CN()
|
||||
explaining_cfg.attack.name = 'all'
|
||||
|
||||
explaining_cfg.attack.name = "all"
|
||||
|
||||
explaining_cfg.accelerator = "auto"
|
||||
|
||||
|
|
|
@ -31,16 +31,16 @@ __all__captum = [
|
|||
|
||||
__all__graphxai = [
|
||||
"CAM",
|
||||
# "GradCAM",
|
||||
# "GNN_LRP",
|
||||
# "GradExplainer",
|
||||
# "GuidedBackPropagation",
|
||||
# "IntegratedGradients",
|
||||
# "PGExplainer",
|
||||
# "PGMExplainer",
|
||||
# "RandomExplainer",
|
||||
# "SubgraphX",
|
||||
# "GraphMASK",
|
||||
"GradCAM",
|
||||
"GNN_LRP",
|
||||
"GradExplainer",
|
||||
"GuidedBackPropagation",
|
||||
"IntegratedGradients",
|
||||
"PGExplainer",
|
||||
"PGMExplainer",
|
||||
"RandomExplainer",
|
||||
"SubgraphX",
|
||||
"GraphMASK",
|
||||
]
|
||||
|
||||
|
||||
|
@ -82,12 +82,12 @@ for epoch in range(1, 2):
|
|||
target = torch.LongTensor([[0]])
|
||||
|
||||
for kind in ["graph"]:
|
||||
for name in __all__graphxai:
|
||||
for name in __all__graphxai + __all__captum:
|
||||
if name in __all__captum:
|
||||
explaining_algorithm = CaptumWrapper(name)
|
||||
elif name in __all__graphxai:
|
||||
explaining_algorithm = GraphXAIWrapper(
|
||||
name, in_channels=in_channels, criterion="cross-entropy"
|
||||
name, in_channels=in_channels, criterion="cross_entropy"
|
||||
)
|
||||
|
||||
print(name)
|
||||
|
@ -105,7 +105,7 @@ for kind in ["graph"]:
|
|||
task_level=kind,
|
||||
return_type="raw",
|
||||
),
|
||||
threshold_config=dict(threshold_type="hard", value=0.5),
|
||||
# threshold_config=dict(threshold_type=None, value=0.5),
|
||||
)
|
||||
explanation = explainer(
|
||||
x=batch.x,
|
||||
|
@ -117,29 +117,29 @@ for kind in ["graph"]:
|
|||
# explanation.__setattr__(
|
||||
# "model_prediction", explainer.get_prediction(x, edge_index)
|
||||
# )
|
||||
explanation_threshold = explanation._apply_masks(
|
||||
node_mask=torch.ones_like(explanation.node_mask).bool()
|
||||
)
|
||||
# explanation_threshold = explanation._apply_masks(
|
||||
# node_mask=torch.ones_like(explanation.node_mask).bool()
|
||||
# )
|
||||
|
||||
print(explanation_threshold.__dict__)
|
||||
# print(explanation_threshold.__dict__)
|
||||
|
||||
for f_name in [
|
||||
"gaussian_noise",
|
||||
"add_edge",
|
||||
"remove_edge",
|
||||
"remove_node",
|
||||
"pgd",
|
||||
"fgsm",
|
||||
]:
|
||||
print(f_name)
|
||||
acc = Attack(name=f_name, model=model, loss=loss)
|
||||
# gt = torch.ones_like(explanation_threshold.node_mask) / 2
|
||||
# mask = explanation_threshold.node_mask.bool()
|
||||
# target = (1 - gt).bool()
|
||||
# target[1] = False
|
||||
# print(mask, target)
|
||||
out = acc.forward(explanation)
|
||||
print(out)
|
||||
# for f_name in [
|
||||
# "gaussian_noise",
|
||||
# "add_edge",
|
||||
# "remove_edge",
|
||||
# "remove_node",
|
||||
# "pgd",
|
||||
# "fgsm",
|
||||
# ]:
|
||||
# print(f_name)
|
||||
# acc = Attack(name=f_name, model=model, loss=loss)
|
||||
# # gt = torch.ones_like(explanation_threshold.node_mask) / 2
|
||||
# # mask = explanation_threshold.node_mask.bool()
|
||||
# # target = (1 - gt).bool()
|
||||
# # target[1] = False
|
||||
# # print(mask, target)
|
||||
# out = acc.forward(explanation)
|
||||
# print(out)
|
||||
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
|
|
|
@ -39,14 +39,9 @@ class Metric(ABC):
|
|||
**kwargs (optional): Additional keyword arguments passed to the
|
||||
model.
|
||||
"""
|
||||
training = self.model.training
|
||||
self.model.eval()
|
||||
print(args, kwargs)
|
||||
|
||||
with torch.no_grad():
|
||||
out = self.model(*args, **kwargs)
|
||||
|
||||
self.model.train(training)
|
||||
out = self.model(*args, **kwargs)[0]
|
||||
|
||||
return out
|
||||
|
||||
|
||||
|
|
|
@ -1,9 +1,9 @@
|
|||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import Tensor
|
||||
from torch.nn import KLDivLoss, Softmax
|
||||
from torch_geometric.explain.explanation import Explanation
|
||||
from torch_geometric.graphgym.config import cfg
|
||||
from torch import Tensor
|
||||
|
||||
from explaining_framework.metric.base import Metric
|
||||
|
||||
|
@ -118,9 +118,19 @@ class Fidelity(Metric):
|
|||
)
|
||||
pos_fidelity = self._fidelity_plus_prob(exp)
|
||||
neg_fidelity = self._fidelity_minus_prob(exp)
|
||||
|
||||
denom = (pos_weight / pos_fidelity) + (neg_weight / (1.0 - neg_fidelity))
|
||||
return 1.0 / denom
|
||||
if (
|
||||
pos_fidelity == 0
|
||||
or pos_fidelity == 1
|
||||
or neg_fidelity == 0
|
||||
or neg_fidelity == 1
|
||||
):
|
||||
return None
|
||||
else:
|
||||
denom = (pos_weight / pos_fidelity) + (neg_weight / (1.0 - neg_fidelity))
|
||||
if denom == 0:
|
||||
return None
|
||||
else:
|
||||
return 1.0 / denom
|
||||
|
||||
def _characterization(
|
||||
self,
|
||||
|
@ -136,8 +146,19 @@ class Fidelity(Metric):
|
|||
pos_fidelity = self._fidelity_plus(exp)
|
||||
neg_fidelity = self._fidelity_minus(exp)
|
||||
|
||||
denom = (pos_weight / pos_fidelity) + (neg_weight / (1.0 - neg_fidelity))
|
||||
return 1.0 / denom
|
||||
if (
|
||||
pos_fidelity == 0
|
||||
or pos_fidelity == 1
|
||||
or neg_fidelity == 0
|
||||
or neg_fidelity == 1
|
||||
):
|
||||
return None
|
||||
else:
|
||||
denom = (pos_weight / pos_fidelity) + (neg_weight / (1.0 - neg_fidelity))
|
||||
if denom == 0:
|
||||
return None
|
||||
else:
|
||||
return 1.0 / denom
|
||||
|
||||
def score(self, exp):
|
||||
self.exp_sub = exp.get_explanation_subgraph()
|
||||
|
|
|
@ -26,7 +26,7 @@ class FGSM(Metric):
|
|||
lower_bound: float = float("-inf"),
|
||||
upper_bound: float = float("inf"),
|
||||
):
|
||||
super().__init__(name=name, model=model)
|
||||
super().__init__(name="fgsm", model=model)
|
||||
self.model = model
|
||||
self.loss = loss
|
||||
self.lower_bound = lower_bound
|
||||
|
@ -51,6 +51,9 @@ class FGSM(Metric):
|
|||
)
|
||||
return input_
|
||||
|
||||
def load_metric(self):
|
||||
pass
|
||||
|
||||
|
||||
class PGD(Metric):
|
||||
def __init__(
|
||||
|
@ -60,7 +63,7 @@ class PGD(Metric):
|
|||
lower_bound: float = float("-inf"),
|
||||
upper_bound: float = float("inf"),
|
||||
):
|
||||
super().__init__(name=name, model=model)
|
||||
super().__init__(name="pgd", model=model)
|
||||
self.model = model
|
||||
self.loss = loss
|
||||
self.lower_bound = lower_bound
|
||||
|
@ -105,6 +108,9 @@ class PGD(Metric):
|
|||
perturbed_input.x = self.bound(perturbed_input.x).detach()
|
||||
return perturbed_input
|
||||
|
||||
def load_metric(self):
|
||||
pass
|
||||
|
||||
def _random_point(
|
||||
self, center: torch.Tensor, radius: float, norm: str
|
||||
) -> torch.Tensor:
|
||||
|
@ -135,6 +141,7 @@ class Attack(Metric):
|
|||
dropout: float = 0.5,
|
||||
loss: torch.nn = None,
|
||||
):
|
||||
|
||||
super().__init__(name=name, model=model)
|
||||
self.name = name
|
||||
self.model = model
|
||||
|
@ -148,12 +155,12 @@ class Attack(Metric):
|
|||
]
|
||||
self.dropout = dropout
|
||||
if loss is None:
|
||||
if cfg.model.loss_fun == "cross-entropy":
|
||||
if cfg.model.loss_fun == "cross_entropy":
|
||||
self.loss = CrossEntropyLoss()
|
||||
if cfg.model.loss_fun == "mse":
|
||||
elif cfg.model.loss_fun == "mse":
|
||||
self.loss = MSELoss()
|
||||
else:
|
||||
raise ValueError
|
||||
raise ValueError(f"{loss} is not supported yet")
|
||||
else:
|
||||
self.loss = loss
|
||||
self.load_metric(name)
|
||||
|
|
|
@ -19,7 +19,7 @@ class Sparsity(Metric):
|
|||
|
||||
def forward(self, exp: Explanation) -> float:
|
||||
out = {}
|
||||
for k, v in exp.to_dict():
|
||||
for k, v in exp.to_dict().items():
|
||||
if "mask" in k and v.dtype == torch.bool:
|
||||
out[k] = torch.mean(mask.float()).item()
|
||||
return out
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import copy
|
||||
import types
|
||||
from inspect import getmembers, isfunction, signature
|
||||
|
||||
|
@ -152,7 +153,7 @@ class GraphStat(object):
|
|||
return maps
|
||||
|
||||
def __call__(self, data):
|
||||
data_ = data.__copy__()
|
||||
data_ = copy.copy(data)
|
||||
datahash = hash(data.__repr__)
|
||||
stats = {}
|
||||
for k, v in self.maps.items():
|
||||
|
@ -160,7 +161,7 @@ class GraphStat(object):
|
|||
_data_ = to_networkx(data)
|
||||
_data_ = _data_.to_undirected()
|
||||
elif k == "torch_geometric":
|
||||
_data_ = data.__copy__()
|
||||
_data_ = copy.copy(data)
|
||||
for name, func in v.items():
|
||||
try:
|
||||
val = func(_data_)
|
||||
|
|
|
@ -122,7 +122,7 @@ class LoadModelInfo(object):
|
|||
if self.info is None:
|
||||
self.set_info()
|
||||
|
||||
model_name = os.path.basename(self.info["xp_dir_name"])
|
||||
model_name = os.path.basename(self.info["xp_dir_path"])
|
||||
model_seed = self.info["seed"]
|
||||
epoch = os.path.basename(self.info["ckpt_path"])
|
||||
model_signature = "-".join(
|
||||
|
|
|
@ -4,12 +4,12 @@ from typing import Any
|
|||
from eixgnn.eixgnn import EiXGNN
|
||||
from scgnn.scgnn import SCGNN
|
||||
from torch_geometric.data import Batch, Data
|
||||
from torch_geometric.loader.dataloader import DataLoader
|
||||
from torch_geometric.explain import Explainer
|
||||
from torch_geometric.graphgym.config import cfg
|
||||
from torch_geometric.graphgym.loader import create_dataset
|
||||
from torch_geometric.graphgym.model_builder import cfg, create_model
|
||||
from torch_geometric.graphgym.utils.device import auto_select_device
|
||||
from torch_geometric.loader.dataloader import DataLoader
|
||||
|
||||
from explaining_framework.config.explainer_config.eixgnn_config import \
|
||||
eixgnn_cfg
|
||||
|
@ -53,6 +53,7 @@ all__graphxai = [
|
|||
"RandomExplainer",
|
||||
"SubgraphX",
|
||||
"GraphMASK",
|
||||
"GNNExplainer",
|
||||
]
|
||||
|
||||
all__own = ["EIXGNN", "SCGNN"]
|
||||
|
@ -82,6 +83,7 @@ all_robust = [
|
|||
"pgd",
|
||||
"fgsm",
|
||||
]
|
||||
all_sparsity = ["l0"]
|
||||
|
||||
|
||||
class ExplainingOutline(object):
|
||||
|
@ -108,6 +110,7 @@ class ExplainingOutline(object):
|
|||
self.load_explainer()
|
||||
self.load_metric()
|
||||
self.load_attack()
|
||||
self.load_dataset_to_dataloader()
|
||||
|
||||
def load_model_info(self):
|
||||
info = LoadModelInfo(
|
||||
|
@ -171,7 +174,12 @@ class ExplainingOutline(object):
|
|||
if isinstance(self.explaining_cfg.dataset.specific_items, int):
|
||||
ind = self.explaining_cfg.dataset.specific_items
|
||||
self.dataset = self.dataset[ind : ind + 1]
|
||||
self.dataset = DataLoader(dataset=dataset, shuffle=False, batch_size=1)
|
||||
elif isinstance(self.explaining_cfg.dataset.specific_items, list):
|
||||
ind = self.explaining_cfg.dataset.specific_items
|
||||
self.dataset = self.dataset[ind]
|
||||
|
||||
def load_dataset_to_dataloader(self):
|
||||
self.dataset = DataLoader(dataset=self.dataset, shuffle=False, batch_size=1)
|
||||
|
||||
def load_explainer(self):
|
||||
self.load_explainer_cfg()
|
||||
|
@ -217,18 +225,20 @@ class ExplainingOutline(object):
|
|||
if self.explaining_cfg is None:
|
||||
self.load_explaining_cfg()
|
||||
|
||||
name_ = self.explaining_cfg.metrics.type
|
||||
name_ = self.explaining_cfg.metrics.name
|
||||
|
||||
if name_ == "all":
|
||||
all_fid_metrics = [Fidelity(name) for name in all_fidelity]
|
||||
all_fid_metrics = [
|
||||
Fidelity(name=name, model=self.model) for name in all_fidelity
|
||||
]
|
||||
all_spa_metrics = [Sparsity(name) for name in all_sparsity]
|
||||
self.metrics = all_acc_metrics + all_fid_metrics
|
||||
self.metrics = all_spa_metrics + all_fid_metrics
|
||||
|
||||
if self.explaining_cfg.dataset.name == "BASHAPES":
|
||||
all_acc_metrics = [Accuracy(name) for name in all_accuracy]
|
||||
self.metrics = self.metrics + all_acc_metrics
|
||||
elif name_ in all_fidelity:
|
||||
self.metrics = [Fidelity(name_)]
|
||||
self.metrics = [Fidelity(name=name_, model=self.model)]
|
||||
elif name_ in all_sparsity:
|
||||
self.metrics = [Sparsity(name_)]
|
||||
elif name_ in all_accuracy:
|
||||
|
@ -250,11 +260,13 @@ class ExplainingOutline(object):
|
|||
self.load_explaining_cfg()
|
||||
name_ = self.explaining_cfg.attack.name
|
||||
if name_ == "all":
|
||||
all_rob_metrics = [Attack(name) for name in all_robust]
|
||||
all_rob_metrics = [
|
||||
Attack(name=name, model=self.model) for name in all_robust
|
||||
]
|
||||
self.attacks = all_rob_metrics
|
||||
elif name_ in all_robust:
|
||||
self.attacks = [Attack(name_)]
|
||||
self.attacks = [Attack(name=name_, model=self.model)]
|
||||
elif name_ is None:
|
||||
slef.attacks = []
|
||||
self.attacks = []
|
||||
else:
|
||||
raise ValueError(f"{name_} is an Attack method that is not supported yet")
|
||||
|
|
|
@ -1,7 +1,10 @@
|
|||
import copy
|
||||
|
||||
import torch
|
||||
from torch import FloatTensor
|
||||
from torch.nn import ReLU
|
||||
from torch_geometric.explain.explanation import Explanation
|
||||
|
||||
|
||||
class Adjust(object):
|
||||
def __init__(
|
||||
|
@ -20,7 +23,7 @@ class Adjust(object):
|
|||
self.apply_relu = False
|
||||
|
||||
def forward(self, exp: Explanation) -> Explanation:
|
||||
exp_ = exp.copy()
|
||||
exp_ = copy.copy(exp)
|
||||
_store = exp_.to_dict()
|
||||
for k, v in _store.items():
|
||||
if "mask" in k:
|
||||
|
@ -61,5 +64,7 @@ class Adjust(object):
|
|||
return mask
|
||||
|
||||
def absolute(self, mask: FloatTensor) -> FloatTensor:
|
||||
print("######################### MASK")
|
||||
print(mask)
|
||||
mask_ = torch.abs(mask)
|
||||
return mask_
|
||||
|
|
|
@ -2,6 +2,7 @@ import copy
|
|||
import json
|
||||
import os
|
||||
|
||||
import torch
|
||||
from torch_geometric.data import Data
|
||||
from torch_geometric.explain.explanation import Explanation
|
||||
|
||||
|
@ -12,7 +13,7 @@ def explanation_verification(exp: Explanation) -> bool:
|
|||
for mask in masks:
|
||||
is_nan = mask.isnan().any().item()
|
||||
is_inf = mask.isinf().any().item()
|
||||
is_const = mask.max()==mask.min()
|
||||
is_const = mask.max() == mask.min()
|
||||
is_ok = exp.validate()
|
||||
if is_nan or is_inf or not is_ok or is_const:
|
||||
is_good = False
|
||||
|
@ -25,8 +26,10 @@ def explanation_verification(exp: Explanation) -> bool:
|
|||
def save_explanation(exp: Explanation, path: str) -> None:
|
||||
data = copy.copy(exp).to_dict()
|
||||
for k, v in data.items():
|
||||
print(k, v)
|
||||
if isinstance(v, torch.Tensor):
|
||||
data[k] = v.detach().cpu().tolist()
|
||||
|
||||
with open(path, "w") as f:
|
||||
json.dump(data, f)
|
||||
|
||||
|
@ -48,8 +51,8 @@ def normalize_explanation_masks(exp: Explanation, p: str = "inf") -> Explanation
|
|||
data = exp.to_dict()
|
||||
for k, v in data.items():
|
||||
if "_mask" in k and isinstance(v, torch.FloatTensor):
|
||||
norm =torch.norm(input=data[k], p=p, dim=None).item()
|
||||
if norm.item()>0:
|
||||
norm = torch.norm(input=data[k], p=p, dim=None).item()
|
||||
if norm.item() > 0:
|
||||
data[k] = data[k] / norm
|
||||
|
||||
return exp
|
||||
|
|
|
@ -31,11 +31,8 @@ def is_exists(path: str) -> bool:
|
|||
|
||||
|
||||
def get_obj_config(obj):
|
||||
config = {k: getattr(obj, k) for k in dir(obj)}
|
||||
config = {
|
||||
k: v
|
||||
for k, v in config.items()
|
||||
if isinstance(v, (int, float, str, bool)) or v is None
|
||||
k: v for k, v in obj.__dict__.items() if isinstance(v, (int, float, str, bool))
|
||||
}
|
||||
return config
|
||||
|
||||
|
|
111
main.py
111
main.py
|
@ -2,9 +2,11 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
|
||||
import copy
|
||||
import os
|
||||
import time
|
||||
|
||||
import torch
|
||||
from torch_geometric import seed_everything
|
||||
from torch_geometric.data.makedirs import makedirs
|
||||
from torch_geometric.explain import Explainer
|
||||
|
@ -16,27 +18,39 @@ from explaining_framework.config.explaining_config import explaining_cfg
|
|||
from explaining_framework.utils.explaining.cmd_args import parse_args
|
||||
from explaining_framework.utils.explaining.outline import ExplainingOutline
|
||||
from explaining_framework.utils.explanation.adjust import Adjust
|
||||
from explaining_framework.utils.io import (obj_config_to_str, read_json,
|
||||
write_json, write_yaml)
|
||||
from explaining_framework.utils.explanation.io import (
|
||||
explanation_verification, load_explanation, save_explanation)
|
||||
from explaining_framework.utils.io import (is_exists, obj_config_to_str,
|
||||
read_json, write_json, write_yaml)
|
||||
|
||||
# inference, time, force,
|
||||
|
||||
|
||||
def get_pred(explanation, force=False):
|
||||
dict_ = explanation.to_dict()
|
||||
if dict_.get("pred") is None or dict_.get("pred_masked") or force:
|
||||
pred = explainer.get_prediction(explanation)
|
||||
def get_pred(explainer, explanation):
|
||||
pred = explainer.get_prediction(x=explanation.x, edge_index=explanation.edge_index)[
|
||||
0
|
||||
]
|
||||
setattr(explanation, "pred", pred)
|
||||
data = explanation.to_dict()
|
||||
if not data.get("node_mask") is None or not data.get("edge_mask") is None:
|
||||
pred_masked = explainer.get_masked_prediction(
|
||||
x=explanation.x,
|
||||
edge_index=explanation.edge_index,
|
||||
node_mask=explanation.node_mask,
|
||||
edge_mask=explanation.edge_mask,
|
||||
)
|
||||
explanation.__setattr__("pred", pred)
|
||||
explanation.__setattr__("pred_masked", pred_masked)
|
||||
return explanation
|
||||
else:
|
||||
return explanation
|
||||
node_mask=data.get("node_mask"),
|
||||
edge_mask=data.get("edge_mask"),
|
||||
)[0]
|
||||
setattr(explanation, "pred_exp", pred_masked)
|
||||
|
||||
|
||||
def get_explanation(explainer, item):
|
||||
explanation = explainer(
|
||||
x=item.x,
|
||||
edge_index=item.edge_index,
|
||||
index=int(item.y),
|
||||
target=item.y,
|
||||
)
|
||||
assert explanation_verification(explanation)
|
||||
return explanation
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -45,8 +59,9 @@ if __name__ == "__main__":
|
|||
auto_select_device()
|
||||
|
||||
# Load components
|
||||
dataset = outline.dataset.to(cfg.accelerator)
|
||||
dataset = outline.dataset
|
||||
model = outline.model.to(cfg.accelerator)
|
||||
model = model.eval()
|
||||
model_info = outline.model_info
|
||||
metrics = outline.metrics
|
||||
explaining_algorithm = outline.explaining_algorithm
|
||||
|
@ -87,53 +102,57 @@ if __name__ == "__main__":
|
|||
return_type=explaining_cfg.model_config.return_type,
|
||||
),
|
||||
)
|
||||
if not explaining_cfg.dataset.specific_items is None:
|
||||
indexes = explaining_cfg.dataset.specific_items
|
||||
else:
|
||||
indexes = range(len(dataset))
|
||||
# Save explaining configuration
|
||||
for index, item in enumerate(dataset):
|
||||
for index, item in zip(indexes, dataset):
|
||||
item = item.to(cfg.accelerator)
|
||||
save_raw_path = os.path.join(global_path, "raw")
|
||||
makedirs(save_raw_path)
|
||||
explanation_path = os.path.join(save_raw_path, f"{index}.json")
|
||||
|
||||
if is_exists(explanation_path):
|
||||
if explaining_cfg.explainer.force:
|
||||
explanation = explainer(
|
||||
x=item.x,
|
||||
edge_index=item.edge_index,
|
||||
index=item.y,
|
||||
target=item.y,
|
||||
)
|
||||
explanation = get_explanation(explainer, item)
|
||||
else:
|
||||
explanation = load_explanation(explanation_path)
|
||||
else:
|
||||
explanation = explainer(
|
||||
x=item.x,
|
||||
edge_index=item.edge_index,
|
||||
index=item.y,
|
||||
target=item.y,
|
||||
)
|
||||
explanation = get_pred(explanation, force=False)
|
||||
explanation = get_explanation(explainer, item)
|
||||
|
||||
explanation = explanation.to(cfg.accelerator)
|
||||
get_pred(explainer=explainer, explanation=explanation)
|
||||
|
||||
save_explanation(explanation, explanation_path)
|
||||
for apply_relu in [True, False]:
|
||||
for apply_absolute in [True, False]:
|
||||
adjust = Adjust(apply_relu=apply_relu, apply_absolute=apply_absolute)
|
||||
save_raw_path = os.path.join(
|
||||
save_raw_path_ = os.path.join(
|
||||
global_path, f"adjust-{obj_config_to_str(adjust)}"
|
||||
)
|
||||
makedirs(save_raw_path)
|
||||
explanation = adjust.forward(explanation)
|
||||
explanation_path = os.path.join(save_raw_path, f"{index}.json")
|
||||
explanation = get_pred(explanation, force=True)
|
||||
save_explanation(explanation, explanation_path)
|
||||
explanation__ = copy.copy(explanation).to(cfg.accelerator)
|
||||
makedirs(save_raw_path_)
|
||||
explanation = adjust.forward(explanation__)
|
||||
explanation_path = os.path.join(save_raw_path_, f"{index}.json")
|
||||
get_pred(explainer, explanation__)
|
||||
save_explanation(explanation__, explanation_path)
|
||||
|
||||
for threshold_approach in ["hard", "topk", "topk_hard"]:
|
||||
for threshold_value in explaining_cfg.threshold_config.value:
|
||||
if threshold_approach == "hard":
|
||||
threshold_values = explaining_cfg.threshold_config.value
|
||||
elif "topk" in threshold_approach:
|
||||
threshold_values = [3, 5, 10, 20]
|
||||
for threshold_value in threshold_values:
|
||||
|
||||
masking_path = os.path.join(
|
||||
save_raw_path,
|
||||
f"threshold={threshold_approach}-value={value}",
|
||||
save_raw_path_,
|
||||
f"threshold={threshold_approach}-value={threshold_value}",
|
||||
)
|
||||
makedirs(masking_path)
|
||||
exp_threshold_path = os.path.join(masking_path, f"{index}.json")
|
||||
if is_exists(exp_threshold_path):
|
||||
explanation = load_explanation(exp_threshold_path)
|
||||
exp_threshold = load_explanation(exp_threshold_path)
|
||||
else:
|
||||
threshold_conf = {
|
||||
"threshold_type": threshold_approach,
|
||||
|
@ -143,17 +162,21 @@ if __name__ == "__main__":
|
|||
threshold_conf
|
||||
)
|
||||
|
||||
expl = copy.copy(explanation)
|
||||
expl = copy.copy(explanation__).to(cfg.accelerator)
|
||||
exp_threshold = explainer._post_process(expl)
|
||||
exp_threshold = get_pred(exp_threshold, force=True)
|
||||
|
||||
save_explanation(exp_threshold, exp_threshold_path)
|
||||
exp_threshold = exp_threshold.to(cfg.accelerator)
|
||||
get_pred(explainer, exp_threshold)
|
||||
save_explanation(exp_threshold, exp_threshold_path)
|
||||
for metric in metrics:
|
||||
metric_path = os.path.join(
|
||||
masking_path, f"{obj_config_to_str(metric)}"
|
||||
)
|
||||
makedirs(metric_path)
|
||||
if is_exists(os.path.join(metric_path, f"{index}.json")):
|
||||
continue
|
||||
else:
|
||||
out = metric.forward(exp_threshold)
|
||||
write_json({f"{metric.name}": out})
|
||||
write_json(
|
||||
{f"{metric.name}": out},
|
||||
os.path.join(metric_path, f"{index}.json"),
|
||||
)
|
||||
|
|
Loading…
Reference in New Issue