Fixing bugs and adding new features

This commit is contained in:
araison 2023-01-02 23:37:40 +01:00
parent db0be0ddb7
commit 9ad5adb33e
13 changed files with 183 additions and 120 deletions

View File

@ -114,7 +114,7 @@ def set_cfg(explaining_cfg):
explaining_cfg.threshold_config.threshold_type = None explaining_cfg.threshold_config.threshold_type = None
explaining_cfg.threshold_config.value = [0.3, 0.5, 0.7] explaining_cfg.threshold_config.value = [i * 0.05 for i in range(21)]
explaining_cfg.threshold_config.relu_and_normalize = True explaining_cfg.threshold_config.relu_and_normalize = True
@ -128,8 +128,7 @@ def set_cfg(explaining_cfg):
explaining_cfg.metrics.force = False explaining_cfg.metrics.force = False
explaining_cfg.attack = CN() explaining_cfg.attack = CN()
explaining_cfg.attack.name = 'all' explaining_cfg.attack.name = "all"
explaining_cfg.accelerator = "auto" explaining_cfg.accelerator = "auto"

View File

@ -31,16 +31,16 @@ __all__captum = [
__all__graphxai = [ __all__graphxai = [
"CAM", "CAM",
# "GradCAM", "GradCAM",
# "GNN_LRP", "GNN_LRP",
# "GradExplainer", "GradExplainer",
# "GuidedBackPropagation", "GuidedBackPropagation",
# "IntegratedGradients", "IntegratedGradients",
# "PGExplainer", "PGExplainer",
# "PGMExplainer", "PGMExplainer",
# "RandomExplainer", "RandomExplainer",
# "SubgraphX", "SubgraphX",
# "GraphMASK", "GraphMASK",
] ]
@ -82,12 +82,12 @@ for epoch in range(1, 2):
target = torch.LongTensor([[0]]) target = torch.LongTensor([[0]])
for kind in ["graph"]: for kind in ["graph"]:
for name in __all__graphxai: for name in __all__graphxai + __all__captum:
if name in __all__captum: if name in __all__captum:
explaining_algorithm = CaptumWrapper(name) explaining_algorithm = CaptumWrapper(name)
elif name in __all__graphxai: elif name in __all__graphxai:
explaining_algorithm = GraphXAIWrapper( explaining_algorithm = GraphXAIWrapper(
name, in_channels=in_channels, criterion="cross-entropy" name, in_channels=in_channels, criterion="cross_entropy"
) )
print(name) print(name)
@ -105,7 +105,7 @@ for kind in ["graph"]:
task_level=kind, task_level=kind,
return_type="raw", return_type="raw",
), ),
threshold_config=dict(threshold_type="hard", value=0.5), # threshold_config=dict(threshold_type=None, value=0.5),
) )
explanation = explainer( explanation = explainer(
x=batch.x, x=batch.x,
@ -117,29 +117,29 @@ for kind in ["graph"]:
# explanation.__setattr__( # explanation.__setattr__(
# "model_prediction", explainer.get_prediction(x, edge_index) # "model_prediction", explainer.get_prediction(x, edge_index)
# ) # )
explanation_threshold = explanation._apply_masks( # explanation_threshold = explanation._apply_masks(
node_mask=torch.ones_like(explanation.node_mask).bool() # node_mask=torch.ones_like(explanation.node_mask).bool()
) # )
print(explanation_threshold.__dict__) # print(explanation_threshold.__dict__)
for f_name in [ # for f_name in [
"gaussian_noise", # "gaussian_noise",
"add_edge", # "add_edge",
"remove_edge", # "remove_edge",
"remove_node", # "remove_node",
"pgd", # "pgd",
"fgsm", # "fgsm",
]: # ]:
print(f_name) # print(f_name)
acc = Attack(name=f_name, model=model, loss=loss) # acc = Attack(name=f_name, model=model, loss=loss)
# gt = torch.ones_like(explanation_threshold.node_mask) / 2 # # gt = torch.ones_like(explanation_threshold.node_mask) / 2
# mask = explanation_threshold.node_mask.bool() # # mask = explanation_threshold.node_mask.bool()
# target = (1 - gt).bool() # # target = (1 - gt).bool()
# target[1] = False # # target[1] = False
# print(mask, target) # # print(mask, target)
out = acc.forward(explanation) # out = acc.forward(explanation)
print(out) # print(out)
except Exception as e: except Exception as e:
traceback.print_exc() traceback.print_exc()

View File

@ -39,14 +39,9 @@ class Metric(ABC):
**kwargs (optional): Additional keyword arguments passed to the **kwargs (optional): Additional keyword arguments passed to the
model. model.
""" """
training = self.model.training print(args, kwargs)
self.model.eval()
with torch.no_grad(): with torch.no_grad():
out = self.model(*args, **kwargs) out = self.model(*args, **kwargs)[0]
self.model.train(training)
return out return out

View File

@ -1,9 +1,9 @@
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from torch import Tensor
from torch.nn import KLDivLoss, Softmax from torch.nn import KLDivLoss, Softmax
from torch_geometric.explain.explanation import Explanation from torch_geometric.explain.explanation import Explanation
from torch_geometric.graphgym.config import cfg from torch_geometric.graphgym.config import cfg
from torch import Tensor
from explaining_framework.metric.base import Metric from explaining_framework.metric.base import Metric
@ -118,9 +118,19 @@ class Fidelity(Metric):
) )
pos_fidelity = self._fidelity_plus_prob(exp) pos_fidelity = self._fidelity_plus_prob(exp)
neg_fidelity = self._fidelity_minus_prob(exp) neg_fidelity = self._fidelity_minus_prob(exp)
if (
denom = (pos_weight / pos_fidelity) + (neg_weight / (1.0 - neg_fidelity)) pos_fidelity == 0
return 1.0 / denom or pos_fidelity == 1
or neg_fidelity == 0
or neg_fidelity == 1
):
return None
else:
denom = (pos_weight / pos_fidelity) + (neg_weight / (1.0 - neg_fidelity))
if denom == 0:
return None
else:
return 1.0 / denom
def _characterization( def _characterization(
self, self,
@ -136,8 +146,19 @@ class Fidelity(Metric):
pos_fidelity = self._fidelity_plus(exp) pos_fidelity = self._fidelity_plus(exp)
neg_fidelity = self._fidelity_minus(exp) neg_fidelity = self._fidelity_minus(exp)
denom = (pos_weight / pos_fidelity) + (neg_weight / (1.0 - neg_fidelity)) if (
return 1.0 / denom pos_fidelity == 0
or pos_fidelity == 1
or neg_fidelity == 0
or neg_fidelity == 1
):
return None
else:
denom = (pos_weight / pos_fidelity) + (neg_weight / (1.0 - neg_fidelity))
if denom == 0:
return None
else:
return 1.0 / denom
def score(self, exp): def score(self, exp):
self.exp_sub = exp.get_explanation_subgraph() self.exp_sub = exp.get_explanation_subgraph()

View File

@ -26,7 +26,7 @@ class FGSM(Metric):
lower_bound: float = float("-inf"), lower_bound: float = float("-inf"),
upper_bound: float = float("inf"), upper_bound: float = float("inf"),
): ):
super().__init__(name=name, model=model) super().__init__(name="fgsm", model=model)
self.model = model self.model = model
self.loss = loss self.loss = loss
self.lower_bound = lower_bound self.lower_bound = lower_bound
@ -51,6 +51,9 @@ class FGSM(Metric):
) )
return input_ return input_
def load_metric(self):
pass
class PGD(Metric): class PGD(Metric):
def __init__( def __init__(
@ -60,7 +63,7 @@ class PGD(Metric):
lower_bound: float = float("-inf"), lower_bound: float = float("-inf"),
upper_bound: float = float("inf"), upper_bound: float = float("inf"),
): ):
super().__init__(name=name, model=model) super().__init__(name="pgd", model=model)
self.model = model self.model = model
self.loss = loss self.loss = loss
self.lower_bound = lower_bound self.lower_bound = lower_bound
@ -105,6 +108,9 @@ class PGD(Metric):
perturbed_input.x = self.bound(perturbed_input.x).detach() perturbed_input.x = self.bound(perturbed_input.x).detach()
return perturbed_input return perturbed_input
def load_metric(self):
pass
def _random_point( def _random_point(
self, center: torch.Tensor, radius: float, norm: str self, center: torch.Tensor, radius: float, norm: str
) -> torch.Tensor: ) -> torch.Tensor:
@ -135,6 +141,7 @@ class Attack(Metric):
dropout: float = 0.5, dropout: float = 0.5,
loss: torch.nn = None, loss: torch.nn = None,
): ):
super().__init__(name=name, model=model) super().__init__(name=name, model=model)
self.name = name self.name = name
self.model = model self.model = model
@ -148,12 +155,12 @@ class Attack(Metric):
] ]
self.dropout = dropout self.dropout = dropout
if loss is None: if loss is None:
if cfg.model.loss_fun == "cross-entropy": if cfg.model.loss_fun == "cross_entropy":
self.loss = CrossEntropyLoss() self.loss = CrossEntropyLoss()
if cfg.model.loss_fun == "mse": elif cfg.model.loss_fun == "mse":
self.loss = MSELoss() self.loss = MSELoss()
else: else:
raise ValueError raise ValueError(f"{loss} is not supported yet")
else: else:
self.loss = loss self.loss = loss
self.load_metric(name) self.load_metric(name)

View File

@ -19,7 +19,7 @@ class Sparsity(Metric):
def forward(self, exp: Explanation) -> float: def forward(self, exp: Explanation) -> float:
out = {} out = {}
for k, v in exp.to_dict(): for k, v in exp.to_dict().items():
if "mask" in k and v.dtype == torch.bool: if "mask" in k and v.dtype == torch.bool:
out[k] = torch.mean(mask.float()).item() out[k] = torch.mean(mask.float()).item()
return out return out

View File

@ -1,6 +1,7 @@
#!/usr/bin/env python #!/usr/bin/env python
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import copy
import types import types
from inspect import getmembers, isfunction, signature from inspect import getmembers, isfunction, signature
@ -152,7 +153,7 @@ class GraphStat(object):
return maps return maps
def __call__(self, data): def __call__(self, data):
data_ = data.__copy__() data_ = copy.copy(data)
datahash = hash(data.__repr__) datahash = hash(data.__repr__)
stats = {} stats = {}
for k, v in self.maps.items(): for k, v in self.maps.items():
@ -160,7 +161,7 @@ class GraphStat(object):
_data_ = to_networkx(data) _data_ = to_networkx(data)
_data_ = _data_.to_undirected() _data_ = _data_.to_undirected()
elif k == "torch_geometric": elif k == "torch_geometric":
_data_ = data.__copy__() _data_ = copy.copy(data)
for name, func in v.items(): for name, func in v.items():
try: try:
val = func(_data_) val = func(_data_)

View File

@ -122,7 +122,7 @@ class LoadModelInfo(object):
if self.info is None: if self.info is None:
self.set_info() self.set_info()
model_name = os.path.basename(self.info["xp_dir_name"]) model_name = os.path.basename(self.info["xp_dir_path"])
model_seed = self.info["seed"] model_seed = self.info["seed"]
epoch = os.path.basename(self.info["ckpt_path"]) epoch = os.path.basename(self.info["ckpt_path"])
model_signature = "-".join( model_signature = "-".join(

View File

@ -4,12 +4,12 @@ from typing import Any
from eixgnn.eixgnn import EiXGNN from eixgnn.eixgnn import EiXGNN
from scgnn.scgnn import SCGNN from scgnn.scgnn import SCGNN
from torch_geometric.data import Batch, Data from torch_geometric.data import Batch, Data
from torch_geometric.loader.dataloader import DataLoader
from torch_geometric.explain import Explainer from torch_geometric.explain import Explainer
from torch_geometric.graphgym.config import cfg from torch_geometric.graphgym.config import cfg
from torch_geometric.graphgym.loader import create_dataset from torch_geometric.graphgym.loader import create_dataset
from torch_geometric.graphgym.model_builder import cfg, create_model from torch_geometric.graphgym.model_builder import cfg, create_model
from torch_geometric.graphgym.utils.device import auto_select_device from torch_geometric.graphgym.utils.device import auto_select_device
from torch_geometric.loader.dataloader import DataLoader
from explaining_framework.config.explainer_config.eixgnn_config import \ from explaining_framework.config.explainer_config.eixgnn_config import \
eixgnn_cfg eixgnn_cfg
@ -53,6 +53,7 @@ all__graphxai = [
"RandomExplainer", "RandomExplainer",
"SubgraphX", "SubgraphX",
"GraphMASK", "GraphMASK",
"GNNExplainer",
] ]
all__own = ["EIXGNN", "SCGNN"] all__own = ["EIXGNN", "SCGNN"]
@ -82,6 +83,7 @@ all_robust = [
"pgd", "pgd",
"fgsm", "fgsm",
] ]
all_sparsity = ["l0"]
class ExplainingOutline(object): class ExplainingOutline(object):
@ -108,6 +110,7 @@ class ExplainingOutline(object):
self.load_explainer() self.load_explainer()
self.load_metric() self.load_metric()
self.load_attack() self.load_attack()
self.load_dataset_to_dataloader()
def load_model_info(self): def load_model_info(self):
info = LoadModelInfo( info = LoadModelInfo(
@ -171,7 +174,12 @@ class ExplainingOutline(object):
if isinstance(self.explaining_cfg.dataset.specific_items, int): if isinstance(self.explaining_cfg.dataset.specific_items, int):
ind = self.explaining_cfg.dataset.specific_items ind = self.explaining_cfg.dataset.specific_items
self.dataset = self.dataset[ind : ind + 1] self.dataset = self.dataset[ind : ind + 1]
self.dataset = DataLoader(dataset=dataset, shuffle=False, batch_size=1) elif isinstance(self.explaining_cfg.dataset.specific_items, list):
ind = self.explaining_cfg.dataset.specific_items
self.dataset = self.dataset[ind]
def load_dataset_to_dataloader(self):
self.dataset = DataLoader(dataset=self.dataset, shuffle=False, batch_size=1)
def load_explainer(self): def load_explainer(self):
self.load_explainer_cfg() self.load_explainer_cfg()
@ -217,18 +225,20 @@ class ExplainingOutline(object):
if self.explaining_cfg is None: if self.explaining_cfg is None:
self.load_explaining_cfg() self.load_explaining_cfg()
name_ = self.explaining_cfg.metrics.type name_ = self.explaining_cfg.metrics.name
if name_ == "all": if name_ == "all":
all_fid_metrics = [Fidelity(name) for name in all_fidelity] all_fid_metrics = [
Fidelity(name=name, model=self.model) for name in all_fidelity
]
all_spa_metrics = [Sparsity(name) for name in all_sparsity] all_spa_metrics = [Sparsity(name) for name in all_sparsity]
self.metrics = all_acc_metrics + all_fid_metrics self.metrics = all_spa_metrics + all_fid_metrics
if self.explaining_cfg.dataset.name == "BASHAPES": if self.explaining_cfg.dataset.name == "BASHAPES":
all_acc_metrics = [Accuracy(name) for name in all_accuracy] all_acc_metrics = [Accuracy(name) for name in all_accuracy]
self.metrics = self.metrics + all_acc_metrics self.metrics = self.metrics + all_acc_metrics
elif name_ in all_fidelity: elif name_ in all_fidelity:
self.metrics = [Fidelity(name_)] self.metrics = [Fidelity(name=name_, model=self.model)]
elif name_ in all_sparsity: elif name_ in all_sparsity:
self.metrics = [Sparsity(name_)] self.metrics = [Sparsity(name_)]
elif name_ in all_accuracy: elif name_ in all_accuracy:
@ -250,11 +260,13 @@ class ExplainingOutline(object):
self.load_explaining_cfg() self.load_explaining_cfg()
name_ = self.explaining_cfg.attack.name name_ = self.explaining_cfg.attack.name
if name_ == "all": if name_ == "all":
all_rob_metrics = [Attack(name) for name in all_robust] all_rob_metrics = [
Attack(name=name, model=self.model) for name in all_robust
]
self.attacks = all_rob_metrics self.attacks = all_rob_metrics
elif name_ in all_robust: elif name_ in all_robust:
self.attacks = [Attack(name_)] self.attacks = [Attack(name=name_, model=self.model)]
elif name_ is None: elif name_ is None:
slef.attacks = [] self.attacks = []
else: else:
raise ValueError(f"{name_} is an Attack method that is not supported yet") raise ValueError(f"{name_} is an Attack method that is not supported yet")

View File

@ -1,7 +1,10 @@
import copy import copy
import torch
from torch import FloatTensor from torch import FloatTensor
from torch.nn import ReLU from torch.nn import ReLU
from torch_geometric.explain.explanation import Explanation
class Adjust(object): class Adjust(object):
def __init__( def __init__(
@ -20,7 +23,7 @@ class Adjust(object):
self.apply_relu = False self.apply_relu = False
def forward(self, exp: Explanation) -> Explanation: def forward(self, exp: Explanation) -> Explanation:
exp_ = exp.copy() exp_ = copy.copy(exp)
_store = exp_.to_dict() _store = exp_.to_dict()
for k, v in _store.items(): for k, v in _store.items():
if "mask" in k: if "mask" in k:
@ -61,5 +64,7 @@ class Adjust(object):
return mask return mask
def absolute(self, mask: FloatTensor) -> FloatTensor: def absolute(self, mask: FloatTensor) -> FloatTensor:
print("######################### MASK")
print(mask)
mask_ = torch.abs(mask) mask_ = torch.abs(mask)
return mask_ return mask_

View File

@ -2,6 +2,7 @@ import copy
import json import json
import os import os
import torch
from torch_geometric.data import Data from torch_geometric.data import Data
from torch_geometric.explain.explanation import Explanation from torch_geometric.explain.explanation import Explanation
@ -12,7 +13,7 @@ def explanation_verification(exp: Explanation) -> bool:
for mask in masks: for mask in masks:
is_nan = mask.isnan().any().item() is_nan = mask.isnan().any().item()
is_inf = mask.isinf().any().item() is_inf = mask.isinf().any().item()
is_const = mask.max()==mask.min() is_const = mask.max() == mask.min()
is_ok = exp.validate() is_ok = exp.validate()
if is_nan or is_inf or not is_ok or is_const: if is_nan or is_inf or not is_ok or is_const:
is_good = False is_good = False
@ -25,8 +26,10 @@ def explanation_verification(exp: Explanation) -> bool:
def save_explanation(exp: Explanation, path: str) -> None: def save_explanation(exp: Explanation, path: str) -> None:
data = copy.copy(exp).to_dict() data = copy.copy(exp).to_dict()
for k, v in data.items(): for k, v in data.items():
print(k, v)
if isinstance(v, torch.Tensor): if isinstance(v, torch.Tensor):
data[k] = v.detach().cpu().tolist() data[k] = v.detach().cpu().tolist()
with open(path, "w") as f: with open(path, "w") as f:
json.dump(data, f) json.dump(data, f)
@ -48,8 +51,8 @@ def normalize_explanation_masks(exp: Explanation, p: str = "inf") -> Explanation
data = exp.to_dict() data = exp.to_dict()
for k, v in data.items(): for k, v in data.items():
if "_mask" in k and isinstance(v, torch.FloatTensor): if "_mask" in k and isinstance(v, torch.FloatTensor):
norm =torch.norm(input=data[k], p=p, dim=None).item() norm = torch.norm(input=data[k], p=p, dim=None).item()
if norm.item()>0: if norm.item() > 0:
data[k] = data[k] / norm data[k] = data[k] / norm
return exp return exp

View File

@ -31,11 +31,8 @@ def is_exists(path: str) -> bool:
def get_obj_config(obj): def get_obj_config(obj):
config = {k: getattr(obj, k) for k in dir(obj)}
config = { config = {
k: v k: v for k, v in obj.__dict__.items() if isinstance(v, (int, float, str, bool))
for k, v in config.items()
if isinstance(v, (int, float, str, bool)) or v is None
} }
return config return config

111
main.py
View File

@ -2,9 +2,11 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# #
import copy
import os import os
import time import time
import torch
from torch_geometric import seed_everything from torch_geometric import seed_everything
from torch_geometric.data.makedirs import makedirs from torch_geometric.data.makedirs import makedirs
from torch_geometric.explain import Explainer from torch_geometric.explain import Explainer
@ -16,27 +18,39 @@ from explaining_framework.config.explaining_config import explaining_cfg
from explaining_framework.utils.explaining.cmd_args import parse_args from explaining_framework.utils.explaining.cmd_args import parse_args
from explaining_framework.utils.explaining.outline import ExplainingOutline from explaining_framework.utils.explaining.outline import ExplainingOutline
from explaining_framework.utils.explanation.adjust import Adjust from explaining_framework.utils.explanation.adjust import Adjust
from explaining_framework.utils.io import (obj_config_to_str, read_json, from explaining_framework.utils.explanation.io import (
write_json, write_yaml) explanation_verification, load_explanation, save_explanation)
from explaining_framework.utils.io import (is_exists, obj_config_to_str,
read_json, write_json, write_yaml)
# inference, time, force, # inference, time, force,
def get_pred(explanation, force=False): def get_pred(explainer, explanation):
dict_ = explanation.to_dict() pred = explainer.get_prediction(x=explanation.x, edge_index=explanation.edge_index)[
if dict_.get("pred") is None or dict_.get("pred_masked") or force: 0
pred = explainer.get_prediction(explanation) ]
setattr(explanation, "pred", pred)
data = explanation.to_dict()
if not data.get("node_mask") is None or not data.get("edge_mask") is None:
pred_masked = explainer.get_masked_prediction( pred_masked = explainer.get_masked_prediction(
x=explanation.x, x=explanation.x,
edge_index=explanation.edge_index, edge_index=explanation.edge_index,
node_mask=explanation.node_mask, node_mask=data.get("node_mask"),
edge_mask=explanation.edge_mask, edge_mask=data.get("edge_mask"),
) )[0]
explanation.__setattr__("pred", pred) setattr(explanation, "pred_exp", pred_masked)
explanation.__setattr__("pred_masked", pred_masked)
return explanation
else: def get_explanation(explainer, item):
return explanation explanation = explainer(
x=item.x,
edge_index=item.edge_index,
index=int(item.y),
target=item.y,
)
assert explanation_verification(explanation)
return explanation
if __name__ == "__main__": if __name__ == "__main__":
@ -45,8 +59,9 @@ if __name__ == "__main__":
auto_select_device() auto_select_device()
# Load components # Load components
dataset = outline.dataset.to(cfg.accelerator) dataset = outline.dataset
model = outline.model.to(cfg.accelerator) model = outline.model.to(cfg.accelerator)
model = model.eval()
model_info = outline.model_info model_info = outline.model_info
metrics = outline.metrics metrics = outline.metrics
explaining_algorithm = outline.explaining_algorithm explaining_algorithm = outline.explaining_algorithm
@ -87,53 +102,57 @@ if __name__ == "__main__":
return_type=explaining_cfg.model_config.return_type, return_type=explaining_cfg.model_config.return_type,
), ),
) )
if not explaining_cfg.dataset.specific_items is None:
indexes = explaining_cfg.dataset.specific_items
else:
indexes = range(len(dataset))
# Save explaining configuration # Save explaining configuration
for index, item in enumerate(dataset): for index, item in zip(indexes, dataset):
item = item.to(cfg.accelerator)
save_raw_path = os.path.join(global_path, "raw") save_raw_path = os.path.join(global_path, "raw")
makedirs(save_raw_path) makedirs(save_raw_path)
explanation_path = os.path.join(save_raw_path, f"{index}.json") explanation_path = os.path.join(save_raw_path, f"{index}.json")
if is_exists(explanation_path): if is_exists(explanation_path):
if explaining_cfg.explainer.force: if explaining_cfg.explainer.force:
explanation = explainer( explanation = get_explanation(explainer, item)
x=item.x,
edge_index=item.edge_index,
index=item.y,
target=item.y,
)
else: else:
explanation = load_explanation(explanation_path) explanation = load_explanation(explanation_path)
else: else:
explanation = explainer( explanation = get_explanation(explainer, item)
x=item.x,
edge_index=item.edge_index, explanation = explanation.to(cfg.accelerator)
index=item.y, get_pred(explainer=explainer, explanation=explanation)
target=item.y,
)
explanation = get_pred(explanation, force=False)
save_explanation(explanation, explanation_path) save_explanation(explanation, explanation_path)
for apply_relu in [True, False]: for apply_relu in [True, False]:
for apply_absolute in [True, False]: for apply_absolute in [True, False]:
adjust = Adjust(apply_relu=apply_relu, apply_absolute=apply_absolute) adjust = Adjust(apply_relu=apply_relu, apply_absolute=apply_absolute)
save_raw_path = os.path.join( save_raw_path_ = os.path.join(
global_path, f"adjust-{obj_config_to_str(adjust)}" global_path, f"adjust-{obj_config_to_str(adjust)}"
) )
makedirs(save_raw_path) explanation__ = copy.copy(explanation).to(cfg.accelerator)
explanation = adjust.forward(explanation) makedirs(save_raw_path_)
explanation_path = os.path.join(save_raw_path, f"{index}.json") explanation = adjust.forward(explanation__)
explanation = get_pred(explanation, force=True) explanation_path = os.path.join(save_raw_path_, f"{index}.json")
save_explanation(explanation, explanation_path) get_pred(explainer, explanation__)
save_explanation(explanation__, explanation_path)
for threshold_approach in ["hard", "topk", "topk_hard"]: for threshold_approach in ["hard", "topk", "topk_hard"]:
for threshold_value in explaining_cfg.threshold_config.value: if threshold_approach == "hard":
threshold_values = explaining_cfg.threshold_config.value
elif "topk" in threshold_approach:
threshold_values = [3, 5, 10, 20]
for threshold_value in threshold_values:
masking_path = os.path.join( masking_path = os.path.join(
save_raw_path, save_raw_path_,
f"threshold={threshold_approach}-value={value}", f"threshold={threshold_approach}-value={threshold_value}",
) )
makedirs(masking_path)
exp_threshold_path = os.path.join(masking_path, f"{index}.json") exp_threshold_path = os.path.join(masking_path, f"{index}.json")
if is_exists(exp_threshold_path): if is_exists(exp_threshold_path):
explanation = load_explanation(exp_threshold_path) exp_threshold = load_explanation(exp_threshold_path)
else: else:
threshold_conf = { threshold_conf = {
"threshold_type": threshold_approach, "threshold_type": threshold_approach,
@ -143,17 +162,21 @@ if __name__ == "__main__":
threshold_conf threshold_conf
) )
expl = copy.copy(explanation) expl = copy.copy(explanation__).to(cfg.accelerator)
exp_threshold = explainer._post_process(expl) exp_threshold = explainer._post_process(expl)
exp_threshold = get_pred(exp_threshold, force=True) exp_threshold = exp_threshold.to(cfg.accelerator)
get_pred(explainer, exp_threshold)
save_explanation(exp_threshold, exp_threshold_path) save_explanation(exp_threshold, exp_threshold_path)
for metric in metrics: for metric in metrics:
metric_path = os.path.join( metric_path = os.path.join(
masking_path, f"{obj_config_to_str(metric)}" masking_path, f"{obj_config_to_str(metric)}"
) )
makedirs(metric_path)
if is_exists(os.path.join(metric_path, f"{index}.json")): if is_exists(os.path.join(metric_path, f"{index}.json")):
continue continue
else: else:
out = metric.forward(exp_threshold) out = metric.forward(exp_threshold)
write_json({f"{metric.name}": out}) write_json(
{f"{metric.name}": out},
os.path.join(metric_path, f"{index}.json"),
)