Adding new features
This commit is contained in:
parent
e2d47af072
commit
a00e73d4f0
|
@ -119,15 +119,20 @@ def set_cfg(explaining_cfg):
|
|||
explaining_cfg.threshold_config.relu_and_normalize = True
|
||||
|
||||
# Select device: 'cpu', 'cuda', 'auto'
|
||||
explaining_cfg.accelerator = "auto"
|
||||
|
||||
# which objectives metrics to computes, either all or one in particular if implemented
|
||||
explaining_cfg.metrics = CN()
|
||||
explaining_cfg.metrics.type = "all"
|
||||
explaining_cfg.metrics.name = "all"
|
||||
|
||||
# Whether or not recomputing metrics if they already exist
|
||||
explaining_cfg.metrics.force = False
|
||||
|
||||
explaining_cfg.attack = CN()
|
||||
explaining_cfg.attack.name = 'all'
|
||||
|
||||
|
||||
explaining_cfg.accelerator = "auto"
|
||||
|
||||
|
||||
def assert_cfg(explaining_cfg):
|
||||
r"""Checks config values, do necessary post processing to the configs"""
|
||||
|
|
|
@ -3,6 +3,8 @@ from abc import ABC, abstractmethod
|
|||
import torch
|
||||
from torch_geometric.explain.explanation import Explanation
|
||||
|
||||
from explaining_framework.utils.io import write_json
|
||||
|
||||
|
||||
class Metric(ABC):
|
||||
def __init__(self, name: str, model: torch.nn.Module = None, **kwargs):
|
||||
|
@ -46,3 +48,12 @@ class Metric(ABC):
|
|||
self.model.train(training)
|
||||
|
||||
return out
|
||||
|
||||
def save_config(self, path) -> None:
|
||||
config = {k: getattr(self, k) for k in dir(self)}
|
||||
config = {
|
||||
k: v
|
||||
for k, v in config.items()
|
||||
if isinstance(v, (int, float, str, bool)) or v is None
|
||||
}
|
||||
write_json(config, path)
|
||||
|
|
|
@ -1,11 +1,12 @@
|
|||
import torch
|
||||
import torch.nn.functional as F
|
||||
from explaining_framework.metric.base import Metric
|
||||
from torch.nn import KLDivLoss, Softmax
|
||||
from torch_geometric.explain.explanation import Explanation
|
||||
from torch_geometric.graphgym.config import cfg
|
||||
|
||||
print("NUM CLASSES cfg dataset")
|
||||
NUM_CLASS = 5
|
||||
from explaining_framework.metric.base import Metric
|
||||
|
||||
NUM_CLASS = cfg.share.dim_out
|
||||
|
||||
|
||||
def softmax(data):
|
||||
|
@ -26,6 +27,8 @@ class Fidelity(Metric):
|
|||
"fidelity_plus_prob",
|
||||
"fidelity_minus_prob",
|
||||
"infidelity_KL",
|
||||
"characterization",
|
||||
"characterization_prob",
|
||||
]
|
||||
|
||||
self.exp_sub = None
|
||||
|
@ -99,9 +102,42 @@ class Fidelity(Metric):
|
|||
self._score_check()
|
||||
prob_initial = softmax(self.s_initial_data)
|
||||
prob_exp = F.log_softmax(self.s_exp_sub, dim=1)
|
||||
print(prob_initial, prob_exp)
|
||||
return (1 - torch.exp(-kl(prob_exp, prob_initial))).item()
|
||||
|
||||
def _characterization_prob(
|
||||
self,
|
||||
exp: Explanation,
|
||||
pos_weight: float = 0.5,
|
||||
neg_weight: float = 0.5,
|
||||
) -> Tensor:
|
||||
if (pos_weight + neg_weight) != 1.0:
|
||||
raise ValueError(
|
||||
f"The weights need to sum up to 1 "
|
||||
f"(got {pos_weight} and {neg_weight})"
|
||||
)
|
||||
pos_fidelity = self._fidelity_plus_prob(exp)
|
||||
neg_fidelity = self._fidelity_minus_prob(exp)
|
||||
|
||||
denom = (pos_weight / pos_fidelity) + (neg_weight / (1.0 - neg_fidelity))
|
||||
return 1.0 / denom
|
||||
|
||||
def _characterization(
|
||||
self,
|
||||
exp: Explanation,
|
||||
pos_weight: float = 0.5,
|
||||
neg_weight: float = 0.5,
|
||||
) -> Tensor:
|
||||
if (pos_weight + neg_weight) != 1.0:
|
||||
raise ValueError(
|
||||
f"The weights need to sum up to 1 "
|
||||
f"(got {pos_weight} and {neg_weight})"
|
||||
)
|
||||
pos_fidelity = self._fidelity_plus(exp)
|
||||
neg_fidelity = self._fidelity_minus(exp)
|
||||
|
||||
denom = (pos_weight / pos_fidelity) + (neg_weight / (1.0 - neg_fidelity))
|
||||
return 1.0 / denom
|
||||
|
||||
def score(self, exp):
|
||||
self.exp_sub = exp.get_explanation_subgraph()
|
||||
self.exp_sub_c = exp.get_complement_subgraph()
|
||||
|
@ -125,6 +161,10 @@ class Fidelity(Metric):
|
|||
self.metric = lambda exp: self._fidelity_minus_prob(exp)
|
||||
if name == "infidelity_KL":
|
||||
self.metric = lambda exp: self._infidelity_KL(exp)
|
||||
if name == "characterization":
|
||||
self.metric = lambda exp: self._characterization(exp)
|
||||
if name == "characterization_prob":
|
||||
self.metric = lambda exp: self._characterization_prob(exp)
|
||||
else:
|
||||
raise ValueError(f"{name} is not supported")
|
||||
return self.metric
|
||||
|
|
|
@ -18,7 +18,7 @@ def compute_gradient(model, inp, target, loss):
|
|||
return torch.autograd.grad(err, inp.x)[0]
|
||||
|
||||
|
||||
class FGSM(object):
|
||||
class FGSM(Metric):
|
||||
def __init__(
|
||||
self,
|
||||
model: torch.nn.Module,
|
||||
|
@ -26,6 +26,7 @@ class FGSM(object):
|
|||
lower_bound: float = float("-inf"),
|
||||
upper_bound: float = float("inf"),
|
||||
):
|
||||
super().__init__(name=name, model=model)
|
||||
self.model = model
|
||||
self.loss = loss
|
||||
self.lower_bound = lower_bound
|
||||
|
@ -51,7 +52,7 @@ class FGSM(object):
|
|||
return input_
|
||||
|
||||
|
||||
class PGD(object):
|
||||
class PGD(Metric):
|
||||
def __init__(
|
||||
self,
|
||||
model: torch.nn.Module,
|
||||
|
@ -59,6 +60,7 @@ class PGD(object):
|
|||
lower_bound: float = float("-inf"),
|
||||
upper_bound: float = float("inf"),
|
||||
):
|
||||
super().__init__(name=name, model=model)
|
||||
self.model = model
|
||||
self.loss = loss
|
||||
self.lower_bound = lower_bound
|
||||
|
|
|
@ -92,6 +92,8 @@ class ExplainingOutline(object):
|
|||
self.model = None
|
||||
self.dataset = None
|
||||
self.model_info = None
|
||||
self.metrics = None
|
||||
self.attacks = None
|
||||
|
||||
self.load_explaining_cfg()
|
||||
self.load_model_info()
|
||||
|
@ -100,6 +102,8 @@ class ExplainingOutline(object):
|
|||
self.load_model()
|
||||
self.load_explainer_cfg()
|
||||
self.load_explainer()
|
||||
self.load_metric()
|
||||
self.load_attack()
|
||||
|
||||
def load_model_info(self):
|
||||
info = LoadModelInfo(
|
||||
|
@ -203,27 +207,36 @@ class ExplainingOutline(object):
|
|||
if self.explaining_cfg is None:
|
||||
self.load_explaining_cfg()
|
||||
|
||||
if self.explaining_cfg.metrics.type == "all":
|
||||
if self.explaining_cfg.dataset.name == "BASHAPES":
|
||||
all_acc_metrics = [Accuracy(name) for name in all_accuracy]
|
||||
name_ = self.explaining_cfg.metrics.type
|
||||
|
||||
if name_ == "all":
|
||||
all_fid_metrics = [Fidelity(name) for name in all_fidelity]
|
||||
all_spa_metrics = [Sparsity(name) for name in all_sparsity]
|
||||
self.metrics = all_acc_metrics + all_fid_metrics
|
||||
|
||||
if self.explaining_cfg.dataset.name == "BASHAPES":
|
||||
all_acc_metrics = [Accuracy(name) for name in all_accuracy]
|
||||
self.metrics = self.metrics + all_acc_metrics
|
||||
elif name_ in all_fidelity:
|
||||
self.metrics = [Fidelity(name_)]
|
||||
elif name_ in all_sparsity:
|
||||
self.metrics = [Sparsity(name_)]
|
||||
elif name_ in all_accuracy:
|
||||
if self.explaining_cfg.dataset.name == "BASHAPES":
|
||||
self.metrics = [Accuracy(name_)]
|
||||
else:
|
||||
raise ValueError(
|
||||
f"The metric {name} is not supported for dataset {self.explaining_cfg.dataset.name} yet, it requires groundtruth explanation"
|
||||
)
|
||||
|
||||
def load_attack(self):
|
||||
if self.cfg is None:
|
||||
self.load_cfg()
|
||||
if self.explaining_cfg is None:
|
||||
self.load_explaining_cfg()
|
||||
name_ = self.explaining_cfg.attack.name
|
||||
if name_ == "all":
|
||||
all_rob_metrics = [Attack(name) for name in all_robust]
|
||||
|
||||
|
||||
class FileManager(object):
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def save(obj: Any, path: str) -> None:
|
||||
pass
|
||||
|
||||
|
||||
PATH = "config_exp.yaml"
|
||||
test = ExplainingOutline(explaining_cfg_path=PATH)
|
||||
self.attacks = all_rob_metrics
|
||||
if name_ in all_robust:
|
||||
self.attacks = [Attack(name_)]
|
||||
|
|
8
main.py
8
main.py
|
@ -0,0 +1,8 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
|
||||
import os
|
||||
from explaining_framework.config.explaining_config import explaining_cfg
|
||||
from explaining_framework.utils.explaining.cmd_args import parse_args
|
||||
from explaining_framework.utils.explaining.outline import parse_args
|
Loading…
Reference in New Issue