Adding new features

This commit is contained in:
araison 2022-12-29 23:29:32 +01:00
parent e2d47af072
commit a00e73d4f0
6 changed files with 103 additions and 24 deletions

View File

@ -119,15 +119,20 @@ def set_cfg(explaining_cfg):
explaining_cfg.threshold_config.relu_and_normalize = True
# Select device: 'cpu', 'cuda', 'auto'
explaining_cfg.accelerator = "auto"
# which objectives metrics to computes, either all or one in particular if implemented
explaining_cfg.metrics = CN()
explaining_cfg.metrics.type = "all"
explaining_cfg.metrics.name = "all"
# Whether or not recomputing metrics if they already exist
explaining_cfg.metrics.force = False
explaining_cfg.attack = CN()
explaining_cfg.attack.name = 'all'
explaining_cfg.accelerator = "auto"
def assert_cfg(explaining_cfg):
r"""Checks config values, do necessary post processing to the configs"""

View File

@ -3,6 +3,8 @@ from abc import ABC, abstractmethod
import torch
from torch_geometric.explain.explanation import Explanation
from explaining_framework.utils.io import write_json
class Metric(ABC):
def __init__(self, name: str, model: torch.nn.Module = None, **kwargs):
@ -46,3 +48,12 @@ class Metric(ABC):
self.model.train(training)
return out
def save_config(self, path) -> None:
config = {k: getattr(self, k) for k in dir(self)}
config = {
k: v
for k, v in config.items()
if isinstance(v, (int, float, str, bool)) or v is None
}
write_json(config, path)

View File

@ -1,11 +1,12 @@
import torch
import torch.nn.functional as F
from explaining_framework.metric.base import Metric
from torch.nn import KLDivLoss, Softmax
from torch_geometric.explain.explanation import Explanation
from torch_geometric.graphgym.config import cfg
print("NUM CLASSES cfg dataset")
NUM_CLASS = 5
from explaining_framework.metric.base import Metric
NUM_CLASS = cfg.share.dim_out
def softmax(data):
@ -26,6 +27,8 @@ class Fidelity(Metric):
"fidelity_plus_prob",
"fidelity_minus_prob",
"infidelity_KL",
"characterization",
"characterization_prob",
]
self.exp_sub = None
@ -99,9 +102,42 @@ class Fidelity(Metric):
self._score_check()
prob_initial = softmax(self.s_initial_data)
prob_exp = F.log_softmax(self.s_exp_sub, dim=1)
print(prob_initial, prob_exp)
return (1 - torch.exp(-kl(prob_exp, prob_initial))).item()
def _characterization_prob(
self,
exp: Explanation,
pos_weight: float = 0.5,
neg_weight: float = 0.5,
) -> Tensor:
if (pos_weight + neg_weight) != 1.0:
raise ValueError(
f"The weights need to sum up to 1 "
f"(got {pos_weight} and {neg_weight})"
)
pos_fidelity = self._fidelity_plus_prob(exp)
neg_fidelity = self._fidelity_minus_prob(exp)
denom = (pos_weight / pos_fidelity) + (neg_weight / (1.0 - neg_fidelity))
return 1.0 / denom
def _characterization(
self,
exp: Explanation,
pos_weight: float = 0.5,
neg_weight: float = 0.5,
) -> Tensor:
if (pos_weight + neg_weight) != 1.0:
raise ValueError(
f"The weights need to sum up to 1 "
f"(got {pos_weight} and {neg_weight})"
)
pos_fidelity = self._fidelity_plus(exp)
neg_fidelity = self._fidelity_minus(exp)
denom = (pos_weight / pos_fidelity) + (neg_weight / (1.0 - neg_fidelity))
return 1.0 / denom
def score(self, exp):
self.exp_sub = exp.get_explanation_subgraph()
self.exp_sub_c = exp.get_complement_subgraph()
@ -125,6 +161,10 @@ class Fidelity(Metric):
self.metric = lambda exp: self._fidelity_minus_prob(exp)
if name == "infidelity_KL":
self.metric = lambda exp: self._infidelity_KL(exp)
if name == "characterization":
self.metric = lambda exp: self._characterization(exp)
if name == "characterization_prob":
self.metric = lambda exp: self._characterization_prob(exp)
else:
raise ValueError(f"{name} is not supported")
return self.metric

View File

@ -18,7 +18,7 @@ def compute_gradient(model, inp, target, loss):
return torch.autograd.grad(err, inp.x)[0]
class FGSM(object):
class FGSM(Metric):
def __init__(
self,
model: torch.nn.Module,
@ -26,6 +26,7 @@ class FGSM(object):
lower_bound: float = float("-inf"),
upper_bound: float = float("inf"),
):
super().__init__(name=name, model=model)
self.model = model
self.loss = loss
self.lower_bound = lower_bound
@ -51,7 +52,7 @@ class FGSM(object):
return input_
class PGD(object):
class PGD(Metric):
def __init__(
self,
model: torch.nn.Module,
@ -59,6 +60,7 @@ class PGD(object):
lower_bound: float = float("-inf"),
upper_bound: float = float("inf"),
):
super().__init__(name=name, model=model)
self.model = model
self.loss = loss
self.lower_bound = lower_bound

View File

@ -92,6 +92,8 @@ class ExplainingOutline(object):
self.model = None
self.dataset = None
self.model_info = None
self.metrics = None
self.attacks = None
self.load_explaining_cfg()
self.load_model_info()
@ -100,6 +102,8 @@ class ExplainingOutline(object):
self.load_model()
self.load_explainer_cfg()
self.load_explainer()
self.load_metric()
self.load_attack()
def load_model_info(self):
info = LoadModelInfo(
@ -203,27 +207,36 @@ class ExplainingOutline(object):
if self.explaining_cfg is None:
self.load_explaining_cfg()
if self.explaining_cfg.metrics.type == "all":
if self.explaining_cfg.dataset.name == "BASHAPES":
all_acc_metrics = [Accuracy(name) for name in all_accuracy]
name_ = self.explaining_cfg.metrics.type
if name_ == "all":
all_fid_metrics = [Fidelity(name) for name in all_fidelity]
all_spa_metrics = [Sparsity(name) for name in all_sparsity]
self.metrics = all_acc_metrics + all_fid_metrics
if self.explaining_cfg.dataset.name == "BASHAPES":
all_acc_metrics = [Accuracy(name) for name in all_accuracy]
self.metrics = self.metrics + all_acc_metrics
elif name_ in all_fidelity:
self.metrics = [Fidelity(name_)]
elif name_ in all_sparsity:
self.metrics = [Sparsity(name_)]
elif name_ in all_accuracy:
if self.explaining_cfg.dataset.name == "BASHAPES":
self.metrics = [Accuracy(name_)]
else:
raise ValueError(
f"The metric {name} is not supported for dataset {self.explaining_cfg.dataset.name} yet, it requires groundtruth explanation"
)
def load_attack(self):
if self.cfg is None:
self.load_cfg()
if self.explaining_cfg is None:
self.load_explaining_cfg()
name_ = self.explaining_cfg.attack.name
if name_ == "all":
all_rob_metrics = [Attack(name) for name in all_robust]
class FileManager(object):
def __init__(self):
pass
def save(obj: Any, path: str) -> None:
pass
PATH = "config_exp.yaml"
test = ExplainingOutline(explaining_cfg_path=PATH)
self.attacks = all_rob_metrics
if name_ in all_robust:
self.attacks = [Attack(name_)]

View File

@ -0,0 +1,8 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
import os
from explaining_framework.config.explaining_config import explaining_cfg
from explaining_framework.utils.explaining.cmd_args import parse_args
from explaining_framework.utils.explaining.outline import parse_args