Reformatting, fixing many bugs

parent 10baa1d443 · commit fbc685503c
@@ -37,12 +37,12 @@ def set_eixgnn_cfg(eixgnn_cfg):
         return eixgnn_cfg

     eixgnn_cfg.seed = 0
-    eixgnn_cfg.L = 50
-    eixgnn_cfg.p = 0.5
-    eixgnn_cfg.importance_sampling_strategy = "node"
+    eixgnn_cfg.L = 5
+    eixgnn_cfg.p = 0.1
+    eixgnn_cfg.importance_sampling_strategy = "neighborhood"
     eixgnn_cfg.domain_similarity = "relative_edge_density"
     eixgnn_cfg.signal_similarity = "KL"
-    eixgnn_cfg.shapley_value_approx = 100
+    eixgnn_cfg.shapley_value_approx = 20


 def assert_eixgnn_cfg(eixgnn_cfg):
@@ -39,6 +39,7 @@ def set_scgnn_cfg(scgnn_cfg):
     scgnn_cfg.depth = "all"
     scgnn_cfg.interest_map_norm = True
     scgnn_cfg.score_map_norm = True
+    scgnn_cfg.target_baseline = "inference"


 def assert_cfg(scgnn_cfg):
@@ -57,9 +57,7 @@ def set_cfg(explaining_cfg):

     explaining_cfg.dataset.name = "Cora"

-    explaining_cfg.dataset.items = None
+    explaining_cfg.dataset.item = None

-    explaining_cfg.run_topological_stat = True
-
     # ----------------------------------------------------------------------- #
     # Model options
@@ -116,7 +114,7 @@ def set_cfg(explaining_cfg):
     explaining_cfg.threshold.config.type = "all"

     explaining_cfg.threshold.value = CN()
-    explaining_cfg.threshold.value.hard = [i * 0.05 for i in range(21)]
+    explaining_cfg.threshold.value.hard = [(i * 10) / 100 for i in range(1, 10)]
     explaining_cfg.threshold.value.topk = [2, 3, 5, 10, 20, 30, 50]

     # which objectives metrics to computes, either all or one in particular if implemented
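Note: the replacement tightens the hard-threshold sweep. For reference, the two comprehensions evaluate to:

    old_hard = [i * 0.05 for i in range(21)]
    # 21 values: 0.0, 0.05, 0.10, ..., 1.0 (includes the degenerate 0.0 and 1.0 endpoints)

    new_hard = [(i * 10) / 100 for i in range(1, 10)]
    # 9 values: 0.1, 0.2, ..., 0.9 (exact decimals, endpoints dropped)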
@@ -131,7 +129,7 @@ def set_cfg(explaining_cfg):
     # Whether or not recomputing metrics if they already exist

     explaining_cfg.adjust = CN()
-    explaining_cfg.adjust.strategy = "rpn"
+    explaining_cfg.adjust.strategy = "rpns"

     explaining_cfg.attack = CN()
     explaining_cfg.attack.name = "all"
@@ -37,10 +37,7 @@ def _load_GNN_LRP(model):


 def _load_GuidedBackPropagation(model, criterion):
-    # return lambda model: GuidedBP(model, criterion)
-    raise ValueError(
-        "GraphXAI GuidedBackPropagation is discarded since already available in Captum for Pytorch Geometric (see CaptumWrapper)"
-    )
+    return GuidedBP(model, criterion)


 def _load_IntegratedGradients(model, criterion):
@@ -106,8 +103,8 @@ class GraphXAIWrapper(ExplainerAlgorithm):
             "GradCAM",
             "GNN_LRP",
             "GradExplainer",
-            "GuidedBP",
-            "IntegratedGradExplainer",
+            "GuidedBackPropagation",
+            "IntegratedGraddients",
             "PGExplainer",
             "PGMExplainer",
             "RandomExplainer",
@@ -234,10 +231,12 @@ class GraphXAIWrapper(ExplainerAlgorithm):
         index: Optional[Union[int, Tensor]] = None,
         **kwargs,
     ):
+
         mask_type = self._get_mask_type()
         self.graphxai_method = self._load_graphxai_method(model)

         if self.model_config.task_level == ModelTaskLevel.node:
+
             attr = self.graphxai_method.get_explanation_node(
                 x=x,
                 edge_index=edge_index,
@@ -245,18 +244,26 @@ class GraphXAIWrapper(ExplainerAlgorithm):
                 node_idx=index,
                 y=target,
             )
+
         elif self.model_config.task_level == ModelTaskLevel.graph:
+
             attr = self.graphxai_method.get_explanation_graph(
                 x=x,
                 edge_index=edge_index,
                 label=target,
                 y=target,
             )
+
         elif self.model_config.task_level == ModelTaskLevel.edge:
+
             attr = self.graphxai_method.get_explanation_link(*args, **kwargs)
+
         else:
+
             raise ValueError(f"{self.model_config.task_level} is not supported yet")
+
         node_mask, edge_mask, node_feat_mask, edge_feat_mask = self._parse_attr(attr)
+
         return Explanation(
             x=x,
             edge_index=edge_index,
@@ -2,16 +2,16 @@ import traceback

 import torch
 import torch.nn as nn
-from explaining_framework.metric.accuracy import Accuracy
-from explaining_framework.metric.fidelity import Fidelity
-from explaining_framework.metric.robust import Attack
-from explaining_framework.metric.sparsity import Sparsity
+from from_captum import CaptumWrapper
+from from_graphxai import GraphXAIWrapper
 from torch_geometric.data import Batch, Data
 from torch_geometric.explain import Explainer
 from torch_geometric.nn import GATConv, GCNConv, GINConv, global_mean_pool

-from from_captum import CaptumWrapper
-from from_graphxai import GraphXAIWrapper
+from explaining_framework.metric.accuracy import Accuracy
+from explaining_framework.metric.fidelity import Fidelity
+from explaining_framework.metric.robust import Attack
+from explaining_framework.metric.sparsity import Sparsity

 __all__captum = [
     "LRP",
@@ -41,6 +41,6 @@ class Metric(ABC):
         """

         with torch.no_grad():
-            out = self.model(*args, **kwargs)[0]
+            out = self.model(*args, **kwargs)

         return out
@@ -1,12 +1,11 @@
 import torch
 import torch.nn.functional as F
+from explaining_framework.metric.base import Metric
 from torch import Tensor
 from torch.nn import KLDivLoss, Softmax
 from torch_geometric.explain.explanation import Explanation
 from torch_geometric.graphgym.config import cfg

-from explaining_framework.metric.base import Metric
-
 NUM_CLASS = cfg.share.dim_out

@@ -58,23 +57,30 @@ class Fidelity(Metric):
         self._score_check()
         inferred_class_initial = torch.argmax(self.s_initial_data, dim=1)
         inferred_class_exp = torch.argmax(self.s_exp_sub_c, dim=1)
-        return torch.mean(
-            (exp.y == inferred_class_initial).float()
-            - (exp.y == inferred_class_exp).float()
-        ).item()
+        return (
+            (
+                (exp.y == inferred_class_initial).float()
+                - (exp.y == inferred_class_exp).float()
+            )
+            .mean()
+            .item()
+        )

     def _fidelity_minus(self, exp: Explanation) -> float:
         self._score_check()
         inferred_class_initial = torch.argmax(self.s_initial_data, dim=1)
         inferred_class_exp = torch.argmax(self.s_exp_sub, dim=1)
-        return torch.mean(
-            (exp.y == inferred_class_initial).float()
-            - (exp.y == inferred_class_exp).float()
-        ).item()
+        return (
+            (
+                (exp.y == inferred_class_initial).float()
+                - (exp.y == inferred_class_exp).float()
+            )
+            .mean()
+            .item()
+        )

     def _fidelity_plus_prob(self, exp: Explanation) -> float:
         self._score_check()
-        # one_hot_emb = F.one_hot(exp.y, num_classes=NUM_CLASS)
         prob_initial = softmax(self.s_initial_data)
         prob_exp = softmax(self.s_exp_sub_c)

@@ -82,9 +88,7 @@ class Fidelity(Metric):
         prob_initial = prob_initial[torch.arange(size), exp.y]
         prob_exp = prob_exp[torch.arange(size), exp.y]

-        return torch.mean(
-            torch.norm(1 - prob_initial, p=1) - torch.norm(1 - prob_exp, p=1)
-        ).item()
+        return (prob_initial - prob_exp).mean().item()

     def _fidelity_minus_prob(self, exp: Explanation) -> float:
         self._score_check()
@@ -95,9 +99,7 @@ class Fidelity(Metric):
         prob_initial = prob_initial[torch.arange(size), exp.y]
         prob_exp = prob_exp[torch.arange(size), exp.y]

-        return torch.mean(
-            torch.norm(1 - prob_initial, p=1) - torch.norm(1 - prob_exp, p=1)
-        ).item()
+        return (prob_initial - prob_exp).mean().item()

     def _infidelity_KL(self, exp: Explanation) -> float:
         self._score_check()
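Note: the old probability-based fidelities subtracted two scalar L1 norms and then took torch.mean of a 0-dim tensor; the rewrite averages the per-sample drop in the probability assigned to the ground-truth class, roughly Fid^prob = mean_i [p_i(y_i | original input) - p_i(y_i | perturbed input)]. A minimal sketch of what the new return line computes, with made-up scores:

    import torch

    # Stand-ins for softmax(self.s_initial_data) and softmax(self.s_exp_sub_c)
    prob_initial = torch.softmax(torch.randn(4, 3), dim=1)
    prob_exp = torch.softmax(torch.randn(4, 3), dim=1)
    y = torch.tensor([0, 2, 1, 0])

    size = y.numel()
    p_init = prob_initial[torch.arange(size), y]  # p(y | original input)
    p_exp = prob_exp[torch.arange(size), y]       # p(y | masked/complement input)
    fid_prob = (p_init - p_exp).mean().item()     # averaged per-sample drop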
@@ -191,6 +193,13 @@ class Fidelity(Metric):
             raise ValueError(f"{name} is not supported")
         return self.metric

+    def reset_score(self):
+        self.exp_sub = None
+        self.exp_sub_c = None
+        self.s_exp_sub = None
+        self.s_exp_sub_c = None
+        self.s_initial_data = None
+
     def forward(self, exp: Explanation):
         self.score(exp)
         return self.metric(exp)
@@ -3,11 +3,13 @@ import copy
 import torch
 import torch.nn.functional as F
 from torch.nn import CrossEntropyLoss, MSELoss
+from torch_geometric.data import Batch, Data
 from torch_geometric.explain.explanation import Explanation
 from torch_geometric.graphgym.config import cfg
 from torch_geometric.utils import add_random_edge, dropout_edge, dropout_node

 from explaining_framework.metric.base import Metric
+from explaining_framework.utils.io import obj_config_to_str


 def compute_gradient(model, inp, target, loss):
@@ -18,121 +20,6 @@ def compute_gradient(model, inp, target, loss):
     return torch.autograd.grad(err, inp.x)[0]


-class FGSM(Metric):
-    def __init__(
-        self,
-        model: torch.nn.Module,
-        loss: torch.nn.Module,
-        lower_bound: float = float("-inf"),
-        upper_bound: float = float("inf"),
-    ):
-        super().__init__(name="fgsm", model=model)
-        self.model = model
-        self.loss = loss
-        self.lower_bound = lower_bound
-        self.upper_bound = upper_bound
-
-        self.bound = lambda x: torch.clamp(
-            x, min=torch.Tensor([lower_bound]), max=torch.Tensor([upper_bound])
-        )
-
-        self.zero_thresh = 10**-6
-
-    def forward(self, input, target, epsilon: float) -> Explanation:
-        input_ = input.clone()
-        grad = compute_gradient(
-            model=self.model, inp=input_, target=target, loss=self.loss
-        )
-        grad = self.bound(grad)
-        input_.x = torch.where(
-            torch.abs(grad) > self.zero_thresh,
-            input_.x - epsilon * torch.sign(grad),
-            input_.x,
-        )
-        return input_
-
-    def load_metric(self):
-        pass
-
-
-class PGD(Metric):
-    def __init__(
-        self,
-        model: torch.nn.Module,
-        loss: torch.nn.Module,
-        lower_bound: float = float("-inf"),
-        upper_bound: float = float("inf"),
-    ):
-        super().__init__(name="pgd", model=model)
-        self.model = model
-        self.loss = loss
-        self.lower_bound = lower_bound
-        self.upper_bound = upper_bound
-        self.bound = lambda x: torch.clamp(
-            x, min=torch.Tensor([lower_bound]), max=torch.Tensor([upper_bound])
-        )
-        self.zero_thresh = 10**-6
-        self.fgsm = FGSM(
-            model=model, loss=loss, lower_bound=lower_bound, upper_bound=upper_bound
-        )
-
-    def forward(
-        self,
-        input,
-        target,
-        epsilon: float,
-        radius: float,
-        step_num: int,
-        random_start: bool = False,
-        norm: str = "inf",
-    ) -> Explanation:
-        def _clip(inputs: Explanation, outputs: Explanation) -> Explanation:
-            diff = outputs.x - inputs.x
-            if norm == "inf":
-                inputs.x = inputs.x + torch.clamp(diff, -radius, radius)
-                return inputs
-            elif norm == "2":
-                inputs.x = inputs.x + torch.renorm(diff, 2, 0, radius)
-                return inputs
-            else:
-                raise AssertionError("Norm constraint must be L2 or Linf.")
-
-        perturbed_input = input
-        if random_start:
-            perturbed_input = self.bound(self._random_point(input.x, radius, norm))
-        for _ in range(step_num):
-            perturbed_input = self.fgsm.forward(
-                input=perturbed_input, epsilon=epsilon, target=target
-            )
-            perturbed_input = _clip(input, perturbed_input)
-            perturbed_input.x = self.bound(perturbed_input.x).detach()
-        return perturbed_input
-
-    def load_metric(self):
-        pass
-
-    def _random_point(
-        self, center: torch.Tensor, radius: float, norm: str
-    ) -> torch.Tensor:
-        r"""
-        A helper function that returns a uniform random point within the ball
-        with the given center and radius. Norm should be either L2 or Linf.
-        """
-        if norm == "2":
-            u = torch.randn_like(center)
-            unit_u = F.normalize(u.view(u.size(0), -1)).view(u.size())
-            d = torch.numel(center[0])
-            r = (torch.rand(u.size(0)) ** (1.0 / d)) * radius
-            r = r[(...,) + (None,) * (r.dim() - 1)]
-            x = r * unit_u
-            return center + x
-        elif norm == "inf":
-            x = torch.rand_like(center) * radius * 2 - radius
-            return center + x
-        else:
-            raise AssertionError("Norm constraint must be L2 or Linf.")
-
-
 class Attack(Metric):
     def __init__(
         self,
@@ -152,8 +39,10 @@ class Attack(Metric):
             "remove_node",
             "pgd",
             "fgsm",
+            "no_attack",
         ]
         self.dropout = dropout
+        self.config = None
         if loss is None:
             if cfg.model.loss_fun == "cross_entropy":
                 self.loss = CrossEntropyLoss()
@@ -166,10 +55,8 @@ class Attack(Metric):
         self.load_metric(name)

     def _gaussian_noise(self, exp) -> Explanation:
-        x = torch.clone(exp.x)
-        x = x + torch.randn(*x.shape)
         exp_ = exp.clone()
-        exp_.x = x
+        exp_.x = exp_.x + torch.randn(*exp_.x.shape).to(exp_.x.device)
         return exp_

     def _add_edge(self, exp, p: float) -> Explanation:
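Note: the rewritten _gaussian_noise keeps the noise on the same device as the node features; the old version sampled torch.randn on the CPU, which fails with a device mismatch once exp.x lives on the GPU. The pattern, in isolation:

    import torch

    x = torch.zeros(4, 3)                            # imagine this on "cuda" in practice
    noisy = x + torch.randn(*x.shape).to(x.device)   # noise follows x's device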
@@ -203,10 +90,15 @@ class Attack(Metric):
     def _load_gaussian_noise(self):
         return lambda exp: self._gaussian_noise(exp)

+    def _load_no_attack(self):
+        return lambda exp: exp
+
     def load_metric(self, name):
         if name in self.authorized_metric:
             if name == "gaussian_noise":
                 self.metric = self._load_gaussian_noise()
+            if name == "no_attack":
+                self.metric = self._load_no_attack()
             if name == "add_edge":
                 self.metric = self._load_add_edge()
             if name == "remove_edge":
@@ -214,21 +106,24 @@ class Attack(Metric):
             if name == "remove_node":
                 self.metric = self._load_remove_node()
             if name == "pgd":
-                pgd = PGD(model=self.model, loss=self.loss)
-                self.metric = lambda exp: pgd.forward(
-                    input=exp,
-                    target=exp.y,
+                pgd = PGD(
+                    model=self.model,
+                    loss=self.loss,
                     epsilon=1,
                     radius=1,
                     step_num=50,
                     random_start=False,
                     norm="inf",
                 )
-            if name == "fgsm":
-                fgsm = FGSM(model=self.model, loss=self.loss)
-                self.metric = lambda exp: fgsm.forward(
-                    input=exp, target=exp.y, epsilon=1
+                self.config = obj_config_to_str(pgd.__dict__)
+                self.metric = lambda exp: pgd.forward(
+                    input=exp,
+                    target=exp.y,
                 )
+            if name == "fgsm":
+                fgsm = FGSM(model=self.model, loss=self.loss, epsilon=1)
+                self.config = obj_config_to_str(fgsm.__dict__)
+                self.metric = lambda exp: fgsm.forward(input=exp, target=exp.y)
             else:
                 raise ValueError(f"{name} is not supported yet")

@@ -237,3 +132,120 @@ class Attack(Metric):
     def forward(self, exp) -> Explanation:
         attack = self.metric(exp)
         return attack
+
+    def get_attacked_prediction(self, data: Data) -> Data:
+        data_ = data.clone()
+        data_attacked = self.forward(data_)
+        pred = self.get_prediction(x=data_.x, edge_index=data_.edge_index)
+        pred_attacked = self.get_prediction(
+            x=data_attacked.x, edge_index=data_attacked.edge_index
+        )
+        setattr(data_, "pred", pred)
+        setattr(data_, "pred_attacked", pred_attacked)
+        return data_
+
+
+class FGSM(Metric):
+    def __init__(
+        self,
+        model: torch.nn.Module,
+        loss: torch.nn.Module,
+        lower_bound: float = float("-inf"),
+        upper_bound: float = float("inf"),
+        epsilon=1,
+    ):
+        super().__init__(name="fgsm", model=model)
+        self.model = model
+        self.loss = loss
+        self.lower_bound = lower_bound
+        self.upper_bound = upper_bound
+        self.epsilon = epsilon
+
+        self.bound = lambda x: torch.clamp(
+            x,
+            min=torch.Tensor([lower_bound]).to(x.device),
+            max=torch.Tensor([upper_bound]).to(x.device),
+        ).to(x.device)
+
+        self.zero_thresh = 10**-6
+
+    def forward(self, input, target) -> Explanation:
+        input_ = input.clone()
+        grad = compute_gradient(
+            model=self.model, inp=input_, target=target, loss=self.loss
+        )
+        grad = self.bound(grad)
+        input_.x = torch.where(
+            torch.abs(grad) > self.zero_thresh,
+            input_.x - self.epsilon * torch.sign(grad),
+            input_.x,
+        )
+        return input_
+
+    def load_metric(self):
+        pass
+
+
+class PGD(Metric):
+    def __init__(
+        self,
+        model: torch.nn.Module,
+        loss: torch.nn.Module,
+        lower_bound: float = float("-inf"),
+        upper_bound: float = float("inf"),
+        epsilon=1,
+        radius=1,
+        step_num=50,
+        random_start=False,
+        norm="inf",
+    ):
+        super().__init__(name="pgd", model=model)
+        self.model = model
+        self.loss = loss
+        self.lower_bound = lower_bound
+        self.upper_bound = upper_bound
+        self.bound = lambda x: torch.clamp(
+            x,
+            min=torch.Tensor([lower_bound]).to(x.device),
+            max=torch.Tensor([upper_bound]).to(x.device),
+        ).to(x.device)
+
+        self.zero_thresh = 10**-6
+        self.fgsm = FGSM(
+            model=model, loss=loss, lower_bound=lower_bound, upper_bound=upper_bound
+        )
+        self.epsilon = epsilon
+        self.radius = radius
+        self.step_num = step_num
+        self.random_start = random_start
+        self.norm = norm
+
+    def forward(
+        self,
+        input,
+        target,
+    ) -> Explanation:
+        def _clip(inputs: Explanation, outputs: Explanation) -> Explanation:
+            diff = outputs.x - inputs.x
+            if self.norm == "inf":
+                inputs.x = inputs.x + torch.clamp(diff, -self.radius, self.radius)
+                return inputs
+            elif self.norm == "2":
+                inputs.x = inputs.x + torch.renorm(diff, 2, 0, self.radius)
+                return inputs
+            else:
+                raise AssertionError("Norm constraint must be L2 or Linf.")
+
+        perturbed_input = input
+        if self.random_start:
+            perturbed_input = self.bound(
+                self._random_point(input.x, self.radius, self.norm)
+            )
+        for _ in range(self.step_num):
+            perturbed_input = self.fgsm.forward(input=perturbed_input, target=target)
+            perturbed_input = _clip(input, perturbed_input)
+            perturbed_input.x = self.bound(perturbed_input.x).detach()
+        return perturbed_input
+
+    def load_metric(self):
+        pass
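Note: FGSM and PGD now live below Attack and take their hyperparameters in the constructor, which is what lets load_metric serialize them via obj_config_to_str(fgsm.__dict__). FGSM applies one signed-gradient step to the features, x' = x - epsilon * sign(grad_x L); PGD iterates that step and clips the accumulated perturbation back into an L-inf or L2 ball of the given radius. A standalone sketch of the FGSM step on a plain tensor (names here are illustrative, not from the repo):

    import torch

    def fgsm_step(x, grad, epsilon=1.0, zero_thresh=1e-6):
        # Move each feature against its gradient sign; leave near-zero
        # gradients untouched, mirroring the zero_thresh guard above.
        return torch.where(grad.abs() > zero_thresh, x - epsilon * grad.sign(), x)

    x = torch.randn(5, 8, requires_grad=True)
    loss = (x ** 2).sum()
    grad = torch.autograd.grad(loss, x)[0]
    x_adv = fgsm_step(x.detach(), grad)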
@@ -20,6 +20,6 @@ class Sparsity(Metric):
     def forward(self, exp: Explanation) -> float:
         out = {}
         for k, v in exp.to_dict().items():
-            if "mask" in k and v.dtype == torch.bool:
-                out[k] = torch.mean(mask.float()).item()
+            if "mask" in k and torch.all(torch.logical_or(v == 0, v == 1)).item():
+                out[k] = torch.mean(v).item()
         return out
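Note: the old branch only matched torch.bool masks and then read an undefined name (mask), a NameError; the new check accepts any mask whose entries are all 0 or 1 and reports its mean, i.e. the fraction of kept elements. In isolation:

    import torch

    mask = torch.tensor([1.0, 0.0, 1.0, 1.0])  # hard mask stored as floats
    is_binary = torch.all(torch.logical_or(mask == 0, mask == 1)).item()
    if is_binary:
        density = torch.mean(mask).item()      # 0.75 here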
@@ -124,13 +124,12 @@ class LoadModelInfo(object):

         model_name = os.path.basename(self.info["xp_dir_path"])
         model_seed = self.info["seed"]
-        epoch = os.path.basename(self.info["ckpt_path"])
         model_signature = "-".join(
             [
                 f"{name}={val}"
                 for name, val in zip(["name", "seed"], [model_name, model_seed])
             ]
-            + [epoch]
+            + [self.which]
         )
         return model_signature

@@ -14,6 +14,7 @@ from torch_geometric.graphgym.loader import create_dataset
 from torch_geometric.graphgym.model_builder import cfg, create_model
 from torch_geometric.graphgym.utils.device import auto_select_device
 from torch_geometric.loader.dataloader import DataLoader
+from yacs.config import CfgNode as CN

 from explaining_framework.config.explainer_config.eixgnn_config import \
     eixgnn_cfg
@@ -22,6 +23,7 @@ from explaining_framework.config.explaining_config import explaining_cfg
 from explaining_framework.explainers.wrappers.from_captum import CaptumWrapper
 from explaining_framework.explainers.wrappers.from_graphxai import \
     GraphXAIWrapper
+from explaining_framework.explainers.wrappers.from_pyg import PYGWrapper
 from explaining_framework.metric.accuracy import Accuracy
 from explaining_framework.metric.base import Metric
 from explaining_framework.metric.fidelity import Fidelity
@@ -47,7 +49,7 @@ all__captum = [
     "GuidedBackprop",
     "GuidedGradCam",
     "InputXGradient",
-    "IntegratedGradients",
+    # "IntegratedGradients",
     "Lime",
     "Occlusion",
     "Saliency",
@@ -67,6 +69,10 @@ all__graphxai = [
     "GraphMASK",
     "GNNExplainer",
 ]
+all__pyg = [
+    # "PGExplainer",
+    # "GNNExplainer",
+]

 all__own = ["EIXGNN", "SCGNN"]

@@ -94,10 +100,11 @@ all_robust = [
     "remove_node",
     "pgd",
     "fgsm",
+    "no_attack",
 ]
 all_sparsity = ["l0"]

-adjust_pattern = "ranp"
+adjust_pattern = "ranps"
 all_adjusts_filters = [
     "".join(filters)
     for i in range(len(adjust_pattern) + 1)
@@ -168,9 +175,9 @@ class ExplainingOutline(object):

     def load_indexes(self):

-        items = self.explaining_cfg.dataset.items
-        if isinstance(items, (list, int)):
-            indexes = items
+        item = self.explaining_cfg.dataset.item
+        if isinstance(item, (list, int)):
+            indexes = item
         else:
             indexes = list(range(len(self.dataset)))
         self.indexes = iter(indexes)
@@ -223,7 +230,7 @@ class ExplainingOutline(object):
             elif self.explaining_cfg.explainer.name == "SCGNN":
                 self.explainer_cfg = copy.copy(scgnn_cfg)
             else:
-                self.explainer_cfg = None
+                self.explainer_cfg = CN()
         else:
             if self.explaining_cfg.explainer.name == "EIXGNN":
                 eixgnn_cfg.merge_from_file(self.explaining_cfg.explainer.cfg)
@@ -241,6 +248,7 @@ class ExplainingOutline(object):
         if self.model is None:
             raise ValueError("Model ckpt has not been loaded, ckpt file not found")
         self.model = self.model.eval()
+        self.model.explain = True

     def load_dataset(self):
         if self.cfg is None:
@@ -252,19 +260,26 @@ class ExplainingOutline(object):
                 f"Expecting that the dataset to perform explanation on is the same as the model has trained on. Get {self.explaining_cfg.dataset.name} for explanation part, and {self.cfg.dataset.name} for the model."
             )
         self.dataset = create_dataset()
-        items = self.explaining_cfg.dataset.items
-        print(items)
-        print(type(items))
-        if isinstance(items, int):
-            self.dataset = self.dataset[items : items + 1]
-        elif isinstance(items, list):
-            self.dataset = self.dataset[items]
+        item = self.explaining_cfg.dataset.item
+        if isinstance(item, int):
+            self.dataset = self.dataset[item : item + 1]
+        elif isinstance(item, list):
+            self.dataset = self.dataset[item]

     def load_dataset_to_dataloader(self, to_iter=True):
         self.dataset = DataLoader(dataset=self.dataset, shuffle=False, batch_size=1)
         if to_iter:
             self.dataset = iter(self.dataset)

+    def reload_dataset(self):
+        self.load_dataset()
+        self.load_indexes()
+
+    def reload_dataloader(self):
+        self.load_dataset()
+        self.load_dataset_to_dataloader()
+        self.load_indexes()
+
     def load_explaining_algorithm(self):
         self.load_explainer_cfg()
         if self.model is None:
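Note: for an int item the loader slices (dataset[item : item + 1]) instead of indexing, so the result is still a PyG Dataset that load_dataset_to_dataloader can wrap; plain dataset[item] would return a single Data object. Illustrative check with a built-in dataset:

    from torch_geometric.datasets import KarateClub

    dataset = KarateClub()
    subset = dataset[0:1]   # still a Dataset, length 1
    single = dataset[0]     # a Data object
    print(type(subset).__name__, len(subset), type(single).__name__)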
@@ -273,14 +288,16 @@ class ExplainingOutline(object):
             self.load_dataset()

         name = self.explaining_cfg.explainer.name
-        if name in all__captum:
-            explaining_algorithm = CaptumWrapper(name)
-        elif name in all__graphxai:
+        if name in all__graphxai:
             explaining_algorithm = GraphXAIWrapper(
                 name,
                 in_channels=self.dataset.num_classes,
                 criterion=self.cfg.model.loss_fun,
             )
+        elif name in all__captum:
+            explaining_algorithm = CaptumWrapper(name)
+        elif name in all__pyg:
+            explaining_algorithm = PYGWrapper(name)
         elif name in all__own:
             if name == "EIXGNN":
                 explaining_algorithm = EiXGNN(
@@ -296,6 +313,7 @@ class ExplainingOutline(object):
                     depth=self.explainer_cfg.depth,
                     interest_map_norm=self.explainer_cfg.interest_map_norm,
                     score_map_norm=self.explainer_cfg.score_map_norm,
+                    target_baseline=self.explainer_cfg.target_baseline,
                 )
             elif name is None:
                 explaining_algorithm = None
@@ -539,6 +557,7 @@ class ExplainingOutline(object):
                 explanation = _get_explanation(self.explainer, item)
             else:
                 explanation = _load_explanation(path)
+                explanation = explanation.to(self.cfg.accelerator)
         else:
             explanation = _get_explanation(self.explainer, item)
         get_pred(self.explainer, explanation)
@@ -590,3 +609,14 @@ class ExplainingOutline(object):
             if item.num_nodes <= 500:
                 stat = self.graphstat(item)
                 write_json(stat, path)
+
+    def get_attack(self, attack: Attack, item: Data, path: str):
+        if is_exists(path):
+            if self.explaining_cfg.explainer.force:
+                data_attack = attack.get_attacked_prediction(item)
+            else:
+                data_attack = _load_explanation(path)
+        else:
+            data_attack = attack.get_attacked_prediction(item)
+            _save_explanation(data_attack, path)
+        return data_attack
@@ -9,37 +9,46 @@ from torch_geometric.explain.explanation import Explanation
 class Adjust(object):
     def __init__(
         self,
-        strategy: str = "rpn",
+        strategy: str = "rpns",
     ):
         self.strategy = strategy

     def forward(self, exp: Explanation) -> Explanation:
         exp_ = exp.clone()
-        _store = exp_.to_dict()
-        for k, v in _store.items():
+        for k, v in exp_.items():
             if "mask" in k:
                 for f_ in self.strategy:
                     if f_ == "r":
-                        _store[k] = self.relu(v)
+                        exp_.__setattr__(k, self.relu(v))
                     if f_ == "a":
-                        _store[k] = self.absolute(v)
+                        exp_.__setattr__(k, self.absolute(v))
                     if f_ == "p":
                         if "edge" in k:
                             pass
                         else:
-                            _store[k] = self.project(v)
+                            exp_.__setattr__(k, self.project(v))
                     if f_ == "n":
-                        _store[k] = self.normalize(v)
+                        exp_.__setattr__(k, self.normalize(v))
+                    if f_ == "s":
+                        exp_.__setattr__(k, self.squeeze_(v))
+
             else:
                 continue

         return exp_

     def relu(self, mask: FloatTensor) -> FloatTensor:
-        relu = ReLU()
+        relu = ReLU(inplace=True)
         mask_ = relu(mask)
         return mask_

+    def squeeze_(self, mask: FloatTensor) -> FloatTensor:
+        if mask.max() == mask.min():
+            return mask
+        else:
+            mask_ = (mask - mask.min()).div(mask.max() - mask.min())
+            return mask_
+
     def normalize(self, mask: FloatTensor) -> FloatTensor:
         norm = torch.norm(mask, p=float("inf"))
         if norm.item() > 0:
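Note: the new "s" filter, squeeze_, is a min-max rescale of a mask into [0, 1], with a guard that returns constant masks unchanged to avoid dividing by zero:

    import torch

    mask = torch.tensor([2.0, 4.0, 6.0])
    if mask.max() == mask.min():
        squeezed = mask  # constant mask: left as-is
    else:
        squeezed = (mask - mask.min()).div(mask.max() - mask.min())
    print(squeezed)  # tensor([0.0000, 0.5000, 1.0000])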
@@ -26,22 +26,45 @@ def write_yaml(data: dict, path: str) -> None:
         data = yaml.dump(data, f)


+def dump_cfg(cfg, path):
+    r"""
+    Dumps the config to the output directory specified in
+    :obj:`cfg.out_dir`
+    Args:
+        cfg (CfgNode): Configuration node
+    """
+    with open(path, "w") as f:
+        cfg.dump(stream=f)
+
+
 def is_exists(path: str) -> bool:
     return os.path.exists(path)


-def get_obj_config(obj):
-    config = {
-        k: v for k, v in obj.__dict__.items() if isinstance(v, (int, float, str, bool))
-    }
+def get_dict_config(d: dict):
+    config = {}
+    for k, v in d.items():
+        if isinstance(v, (int, float, str, bool)):
+            config[k] = val_check(v)
     return config


+def val_check(v):
+    if v == float("-inf"):
+        return "minus_inf"
+    else:
+        return v
+
+
 def save_obj_config(obj, path) -> None:
     config = get_obj_config(obj)
     write_json(config, path)


 def obj_config_to_str(obj) -> str:
-    config = get_obj_config(obj)
-    return "-".join([f"{k}={v}" for k, v in config.items()])
+    if isinstance(obj, dict):
+        config = get_dict_config(obj)
+        return "-".join([f"{k}={v}" for k, v in config.items()])
+    else:
+        config = get_dict_config(obj.__dict__)
+        return "-".join([f"{k}={v}" for k, v in config.items()])
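Note: obj_config_to_str now accepts either an object or a plain dict (as used with pgd.__dict__ in robust.py), keeping only primitive-valued entries; val_check rewrites float("-inf") so the signature stays filesystem-friendly. (save_obj_config still calls the removed get_obj_config in this commit, as shown.) A self-contained re-implementation of the behavior, for illustration only:

    def val_check(v):
        return "minus_inf" if v == float("-inf") else v

    def config_to_str(d: dict) -> str:
        keep = {k: val_check(v) for k, v in d.items()
                if isinstance(v, (int, float, str, bool))}
        return "-".join(f"{k}={v}" for k, v in keep.items())

    class Dummy:
        def __init__(self):
            self.name = "fgsm"
            self.epsilon = 1
            self.lower_bound = float("-inf")

    print(config_to_str(Dummy().__dict__))
    # name=fgsm-epsilon=1-lower_bound=minus_inf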
main.py · 128 changed lines
@@ -18,8 +18,9 @@ from explaining_framework.config.explaining_config import explaining_cfg
 from explaining_framework.utils.explaining.cmd_args import parse_args
 from explaining_framework.utils.explaining.outline import ExplainingOutline
 from explaining_framework.utils.explanation.adjust import Adjust
-from explaining_framework.utils.io import (is_exists, obj_config_to_str,
-                                           read_json, write_json, write_yaml)
+from explaining_framework.utils.io import (dump_cfg, is_exists,
+                                           obj_config_to_str, read_json,
+                                           write_json)

 # inference, time, force,

@@ -27,65 +28,100 @@ from explaining_framework.utils.io import (is_exists, obj_config_to_str,
 if __name__ == "__main__":
     args = parse_args()
     outline = ExplainingOutline(args.explaining_cfg_file)
-    print(outline.explaining_cfg)
-    out_dir = os.path.join(outline.explaining_cfg.out_dir, outline.model_signature)
+    out_dir = os.path.join(
+        outline.explaining_cfg.out_dir,
+        outline.cfg.dataset.name,
+        outline.model_signature,
+    )
     makedirs(out_dir)

-    write_yaml(outline.cfg, os.path.join(out_dir, "config.yaml"))
+    dump_cfg(outline.cfg, os.path.join(out_dir, "config.yaml"))
     write_json(outline.model_info, os.path.join(out_dir, "info.json"))

     explainer_path = os.path.join(
         out_dir,
-        outline.explaining_cfg.explainer.name
-        + "_"
-        + obj_config_to_str(outline.explaining_algorithm),
+        outline.explaining_cfg.explainer.name,
+        obj_config_to_str(outline.explaining_algorithm),
     )

     makedirs(explainer_path)
-    write_yaml(
-        outline.explaining_cfg, os.path.join(explainer_path, explaining_cfg.cfg_dest)
+    dump_cfg(
+        outline.explainer_cfg,
+        os.path.join(explainer_path, "explainer_cfg.yaml"),
     )
-    write_yaml(
-        outline.explainer_cfg, os.path.join(explainer_path, "explainer_cfg.yaml")
+    dump_cfg(
+        outline.explaining_cfg,
+        os.path.join(explainer_path, explaining_cfg.cfg_dest),
     )

-    specific_explainer_path = os.path.join(
-        explainer_path, obj_config_to_str(outline.explaining_algorithm)
-    )
-    makedirs(specific_explainer_path)
-
-    raw_path = os.path.join(specific_explainer_path, "raw")
-    makedirs(raw_path)
-
     item, index = outline.get_item()
     while not (item is None or index is None):
-        explanation_path = os.path.join(raw_path, f"{index}.json")
-        raw_exp = outline.get_explanation(item=item, path=explanation_path)
-        for adjust in outline.adjusts:
-            adjust_path = os.path.join(raw_path, f"adjust-{obj_config_to_str(adjust)}")
-            makedirs(adjust_path)
-            exp_adjust_path = os.path.join(adjust_path, f"{index}.json")
-            exp_adjust = outline.get_adjust(
-                adjust=adjust, item=raw_exp, path=exp_adjust_path
-            )
-            for threshold_conf in outline.thresholds_configs:
-                outline.set_explainer_threshold_config(threshold_conf)
-                masking_path = os.path.join(
-                    adjust_path,
-                    "-".join([f"{k}={v}" for k, v in threshold_conf.items()]),
-                )
-                makedirs(masking_path)
-                exp_masked_path = os.path.join(masking_path, f"{index}.json")
-                exp_masked = outline.get_threshold(
-                    item=exp_adjust, path=exp_masked_path
-                )
-                for metric in outline.metrics:
-                    metric_path = os.path.join(
-                        masking_path, f"{obj_config_to_str(metric)}"
-                    )
-                    makedirs(metric_path)
-                    metric_path = os.path.join(metric_path, f"{index}.json")
-                    out_metric = outline.get_metric(
-                        metric=metric, item=exp_masked, path=metric_path
-                    )
+        for attack in outline.attacks:
+            attack_path = os.path.join(
+                out_dir, attack.__class__.__name__, obj_config_to_str(attack)
+            )
+            makedirs(attack_path)
+            data_attack_path = os.path.join(attack_path, f"{index}.json")
+            data_attack = outline.get_attack(
+                attack=attack, item=item, path=data_attack_path
+            )
         item, index = outline.get_item()
+
+    outline.reload_dataloader()
+    makedirs(explainer_path)
+
+    item, index = outline.get_item()
+    while not (item is None or index is None):
+        for attack in outline.attacks:
+            attack_path_ = os.path.join(
+                explainer_path, attack.__class__.__name__, obj_config_to_str(attack)
+            )
+            makedirs(attack_path_)
+            data_attack_path_ = os.path.join(attack_path_, f"{index}.json")
+            attack_data = outline.get_attack(
+                attack=attack, item=item, path=data_attack_path_
+            )
+            exp = outline.get_explanation(item=attack_data, path=data_attack_path_)
+            for adjust in outline.adjusts:
+                adjust_path = os.path.join(
+                    attack_path_, adjust.__class__.__name__, obj_config_to_str(adjust)
+                )
+                makedirs(adjust_path)
+                exp_adjust_path = os.path.join(adjust_path, f"{index}.json")
+                exp_adjust = outline.get_adjust(
+                    adjust=adjust, item=exp, path=exp_adjust_path
+                )
+                for threshold_conf in outline.thresholds_configs:
+                    outline.set_explainer_threshold_config(threshold_conf)
+                    masking_path = os.path.join(
+                        adjust_path,
+                        "ThresholdConfig",
+                        obj_config_to_str(threshold_conf),
+                    )
+                    makedirs(masking_path)
+                    exp_masked_path = os.path.join(masking_path, f"{index}.json")
+                    exp_masked = outline.get_threshold(
+                        item=exp_adjust, path=exp_masked_path
+                    )
+                    for metric in outline.metrics:
+                        metric_path = os.path.join(
+                            masking_path,
+                            metric.__class__.__name__,
+                            obj_config_to_str(metric),
+                        )
+                        makedirs(metric_path)
+                        metric_path = os.path.join(metric_path, f"{index}.json")
+                        out_metric = outline.get_metric(
+                            metric=metric, item=exp_masked, path=metric_path
+                        )
+                        print("#################################")
+                        print("Attack", attack.name)
+                        print(
+                            "ThresholdConfig",
+                            "-".join([f"{k}={v}" for k, v in threshold_conf.items()]),
+                        )
+                        print("Metric", metric.name)
+                        print("Val", out_metric)
+                        print("Index", index)
+                        print("#################################")
+
+        item, index = outline.get_item()
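Note: every stage now nests its outputs under a class name plus its serialized config. A purely illustrative reconstruction of one metric file's path (all segment values invented):

    import os

    metric_file = os.path.join(
        "results", "Cora", "name=xp-seed=0-best",     # out_dir / dataset / model signature
        "CaptumWrapper", "name=Saliency",             # explainer name / explainer config
        "Attack", "name=fgsm-epsilon=1",              # attack class / attack config
        "Adjust", "strategy=rpns",                    # adjust class / adjust config
        "ThresholdConfig", "type=hard-value=0.5",     # threshold config
        "Fidelity", "name=fidelity_plus",             # metric class / metric config
        "0.json",                                     # one file per dataset index
    )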