Adding new features
This commit is contained in:
parent
e2d47af072
commit
a00e73d4f0
6 changed files with 103 additions and 24 deletions
|
@ -119,15 +119,20 @@ def set_cfg(explaining_cfg):
|
||||||
explaining_cfg.threshold_config.relu_and_normalize = True
|
explaining_cfg.threshold_config.relu_and_normalize = True
|
||||||
|
|
||||||
# Select device: 'cpu', 'cuda', 'auto'
|
# Select device: 'cpu', 'cuda', 'auto'
|
||||||
explaining_cfg.accelerator = "auto"
|
|
||||||
|
|
||||||
# which objectives metrics to computes, either all or one in particular if implemented
|
# which objectives metrics to computes, either all or one in particular if implemented
|
||||||
explaining_cfg.metrics = CN()
|
explaining_cfg.metrics = CN()
|
||||||
explaining_cfg.metrics.type = "all"
|
explaining_cfg.metrics.name = "all"
|
||||||
|
|
||||||
# Whether or not recomputing metrics if they already exist
|
# Whether or not recomputing metrics if they already exist
|
||||||
explaining_cfg.metrics.force = False
|
explaining_cfg.metrics.force = False
|
||||||
|
|
||||||
|
explaining_cfg.attack = CN()
|
||||||
|
explaining_cfg.attack.name = 'all'
|
||||||
|
|
||||||
|
|
||||||
|
explaining_cfg.accelerator = "auto"
|
||||||
|
|
||||||
|
|
||||||
def assert_cfg(explaining_cfg):
|
def assert_cfg(explaining_cfg):
|
||||||
r"""Checks config values, do necessary post processing to the configs"""
|
r"""Checks config values, do necessary post processing to the configs"""
|
||||||
|
|
|
@ -3,6 +3,8 @@ from abc import ABC, abstractmethod
|
||||||
import torch
|
import torch
|
||||||
from torch_geometric.explain.explanation import Explanation
|
from torch_geometric.explain.explanation import Explanation
|
||||||
|
|
||||||
|
from explaining_framework.utils.io import write_json
|
||||||
|
|
||||||
|
|
||||||
class Metric(ABC):
|
class Metric(ABC):
|
||||||
def __init__(self, name: str, model: torch.nn.Module = None, **kwargs):
|
def __init__(self, name: str, model: torch.nn.Module = None, **kwargs):
|
||||||
|
@ -46,3 +48,12 @@ class Metric(ABC):
|
||||||
self.model.train(training)
|
self.model.train(training)
|
||||||
|
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
def save_config(self, path) -> None:
|
||||||
|
config = {k: getattr(self, k) for k in dir(self)}
|
||||||
|
config = {
|
||||||
|
k: v
|
||||||
|
for k, v in config.items()
|
||||||
|
if isinstance(v, (int, float, str, bool)) or v is None
|
||||||
|
}
|
||||||
|
write_json(config, path)
|
||||||
|
|
|
@ -1,11 +1,12 @@
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from explaining_framework.metric.base import Metric
|
|
||||||
from torch.nn import KLDivLoss, Softmax
|
from torch.nn import KLDivLoss, Softmax
|
||||||
from torch_geometric.explain.explanation import Explanation
|
from torch_geometric.explain.explanation import Explanation
|
||||||
|
from torch_geometric.graphgym.config import cfg
|
||||||
|
|
||||||
print("NUM CLASSES cfg dataset")
|
from explaining_framework.metric.base import Metric
|
||||||
NUM_CLASS = 5
|
|
||||||
|
NUM_CLASS = cfg.share.dim_out
|
||||||
|
|
||||||
|
|
||||||
def softmax(data):
|
def softmax(data):
|
||||||
|
@ -26,6 +27,8 @@ class Fidelity(Metric):
|
||||||
"fidelity_plus_prob",
|
"fidelity_plus_prob",
|
||||||
"fidelity_minus_prob",
|
"fidelity_minus_prob",
|
||||||
"infidelity_KL",
|
"infidelity_KL",
|
||||||
|
"characterization",
|
||||||
|
"characterization_prob",
|
||||||
]
|
]
|
||||||
|
|
||||||
self.exp_sub = None
|
self.exp_sub = None
|
||||||
|
@ -99,9 +102,42 @@ class Fidelity(Metric):
|
||||||
self._score_check()
|
self._score_check()
|
||||||
prob_initial = softmax(self.s_initial_data)
|
prob_initial = softmax(self.s_initial_data)
|
||||||
prob_exp = F.log_softmax(self.s_exp_sub, dim=1)
|
prob_exp = F.log_softmax(self.s_exp_sub, dim=1)
|
||||||
print(prob_initial, prob_exp)
|
|
||||||
return (1 - torch.exp(-kl(prob_exp, prob_initial))).item()
|
return (1 - torch.exp(-kl(prob_exp, prob_initial))).item()
|
||||||
|
|
||||||
|
def _characterization_prob(
|
||||||
|
self,
|
||||||
|
exp: Explanation,
|
||||||
|
pos_weight: float = 0.5,
|
||||||
|
neg_weight: float = 0.5,
|
||||||
|
) -> Tensor:
|
||||||
|
if (pos_weight + neg_weight) != 1.0:
|
||||||
|
raise ValueError(
|
||||||
|
f"The weights need to sum up to 1 "
|
||||||
|
f"(got {pos_weight} and {neg_weight})"
|
||||||
|
)
|
||||||
|
pos_fidelity = self._fidelity_plus_prob(exp)
|
||||||
|
neg_fidelity = self._fidelity_minus_prob(exp)
|
||||||
|
|
||||||
|
denom = (pos_weight / pos_fidelity) + (neg_weight / (1.0 - neg_fidelity))
|
||||||
|
return 1.0 / denom
|
||||||
|
|
||||||
|
def _characterization(
|
||||||
|
self,
|
||||||
|
exp: Explanation,
|
||||||
|
pos_weight: float = 0.5,
|
||||||
|
neg_weight: float = 0.5,
|
||||||
|
) -> Tensor:
|
||||||
|
if (pos_weight + neg_weight) != 1.0:
|
||||||
|
raise ValueError(
|
||||||
|
f"The weights need to sum up to 1 "
|
||||||
|
f"(got {pos_weight} and {neg_weight})"
|
||||||
|
)
|
||||||
|
pos_fidelity = self._fidelity_plus(exp)
|
||||||
|
neg_fidelity = self._fidelity_minus(exp)
|
||||||
|
|
||||||
|
denom = (pos_weight / pos_fidelity) + (neg_weight / (1.0 - neg_fidelity))
|
||||||
|
return 1.0 / denom
|
||||||
|
|
||||||
def score(self, exp):
|
def score(self, exp):
|
||||||
self.exp_sub = exp.get_explanation_subgraph()
|
self.exp_sub = exp.get_explanation_subgraph()
|
||||||
self.exp_sub_c = exp.get_complement_subgraph()
|
self.exp_sub_c = exp.get_complement_subgraph()
|
||||||
|
@ -125,6 +161,10 @@ class Fidelity(Metric):
|
||||||
self.metric = lambda exp: self._fidelity_minus_prob(exp)
|
self.metric = lambda exp: self._fidelity_minus_prob(exp)
|
||||||
if name == "infidelity_KL":
|
if name == "infidelity_KL":
|
||||||
self.metric = lambda exp: self._infidelity_KL(exp)
|
self.metric = lambda exp: self._infidelity_KL(exp)
|
||||||
|
if name == "characterization":
|
||||||
|
self.metric = lambda exp: self._characterization(exp)
|
||||||
|
if name == "characterization_prob":
|
||||||
|
self.metric = lambda exp: self._characterization_prob(exp)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"{name} is not supported")
|
raise ValueError(f"{name} is not supported")
|
||||||
return self.metric
|
return self.metric
|
||||||
|
|
|
@ -18,7 +18,7 @@ def compute_gradient(model, inp, target, loss):
|
||||||
return torch.autograd.grad(err, inp.x)[0]
|
return torch.autograd.grad(err, inp.x)[0]
|
||||||
|
|
||||||
|
|
||||||
class FGSM(object):
|
class FGSM(Metric):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
model: torch.nn.Module,
|
model: torch.nn.Module,
|
||||||
|
@ -26,6 +26,7 @@ class FGSM(object):
|
||||||
lower_bound: float = float("-inf"),
|
lower_bound: float = float("-inf"),
|
||||||
upper_bound: float = float("inf"),
|
upper_bound: float = float("inf"),
|
||||||
):
|
):
|
||||||
|
super().__init__(name=name, model=model)
|
||||||
self.model = model
|
self.model = model
|
||||||
self.loss = loss
|
self.loss = loss
|
||||||
self.lower_bound = lower_bound
|
self.lower_bound = lower_bound
|
||||||
|
@ -51,7 +52,7 @@ class FGSM(object):
|
||||||
return input_
|
return input_
|
||||||
|
|
||||||
|
|
||||||
class PGD(object):
|
class PGD(Metric):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
model: torch.nn.Module,
|
model: torch.nn.Module,
|
||||||
|
@ -59,6 +60,7 @@ class PGD(object):
|
||||||
lower_bound: float = float("-inf"),
|
lower_bound: float = float("-inf"),
|
||||||
upper_bound: float = float("inf"),
|
upper_bound: float = float("inf"),
|
||||||
):
|
):
|
||||||
|
super().__init__(name=name, model=model)
|
||||||
self.model = model
|
self.model = model
|
||||||
self.loss = loss
|
self.loss = loss
|
||||||
self.lower_bound = lower_bound
|
self.lower_bound = lower_bound
|
||||||
|
|
|
@ -92,6 +92,8 @@ class ExplainingOutline(object):
|
||||||
self.model = None
|
self.model = None
|
||||||
self.dataset = None
|
self.dataset = None
|
||||||
self.model_info = None
|
self.model_info = None
|
||||||
|
self.metrics = None
|
||||||
|
self.attacks = None
|
||||||
|
|
||||||
self.load_explaining_cfg()
|
self.load_explaining_cfg()
|
||||||
self.load_model_info()
|
self.load_model_info()
|
||||||
|
@ -100,6 +102,8 @@ class ExplainingOutline(object):
|
||||||
self.load_model()
|
self.load_model()
|
||||||
self.load_explainer_cfg()
|
self.load_explainer_cfg()
|
||||||
self.load_explainer()
|
self.load_explainer()
|
||||||
|
self.load_metric()
|
||||||
|
self.load_attack()
|
||||||
|
|
||||||
def load_model_info(self):
|
def load_model_info(self):
|
||||||
info = LoadModelInfo(
|
info = LoadModelInfo(
|
||||||
|
@ -203,27 +207,36 @@ class ExplainingOutline(object):
|
||||||
if self.explaining_cfg is None:
|
if self.explaining_cfg is None:
|
||||||
self.load_explaining_cfg()
|
self.load_explaining_cfg()
|
||||||
|
|
||||||
if self.explaining_cfg.metrics.type == "all":
|
name_ = self.explaining_cfg.metrics.type
|
||||||
|
|
||||||
|
if name_ == "all":
|
||||||
|
all_fid_metrics = [Fidelity(name) for name in all_fidelity]
|
||||||
|
all_spa_metrics = [Sparsity(name) for name in all_sparsity]
|
||||||
|
self.metrics = all_acc_metrics + all_fid_metrics
|
||||||
|
|
||||||
if self.explaining_cfg.dataset.name == "BASHAPES":
|
if self.explaining_cfg.dataset.name == "BASHAPES":
|
||||||
all_acc_metrics = [Accuracy(name) for name in all_accuracy]
|
all_acc_metrics = [Accuracy(name) for name in all_accuracy]
|
||||||
all_fid_metrics = [Fidelity(name) for name in all_fidelity]
|
self.metrics = self.metrics + all_acc_metrics
|
||||||
all_spa_metrics = [Sparsity(name) for name in all_sparsity]
|
elif name_ in all_fidelity:
|
||||||
|
self.metrics = [Fidelity(name_)]
|
||||||
|
elif name_ in all_sparsity:
|
||||||
|
self.metrics = [Sparsity(name_)]
|
||||||
|
elif name_ in all_accuracy:
|
||||||
|
if self.explaining_cfg.dataset.name == "BASHAPES":
|
||||||
|
self.metrics = [Accuracy(name_)]
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"The metric {name} is not supported for dataset {self.explaining_cfg.dataset.name} yet, it requires groundtruth explanation"
|
||||||
|
)
|
||||||
|
|
||||||
def load_attack(self):
|
def load_attack(self):
|
||||||
if self.cfg is None:
|
if self.cfg is None:
|
||||||
self.load_cfg()
|
self.load_cfg()
|
||||||
if self.explaining_cfg is None:
|
if self.explaining_cfg is None:
|
||||||
self.load_explaining_cfg()
|
self.load_explaining_cfg()
|
||||||
all_rob_metrics = [Attack(name) for name in all_robust]
|
name_ = self.explaining_cfg.attack.name
|
||||||
|
if name_ == "all":
|
||||||
|
all_rob_metrics = [Attack(name) for name in all_robust]
|
||||||
class FileManager(object):
|
self.attacks = all_rob_metrics
|
||||||
def __init__(self):
|
if name_ in all_robust:
|
||||||
pass
|
self.attacks = [Attack(name_)]
|
||||||
|
|
||||||
def save(obj: Any, path: str) -> None:
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
PATH = "config_exp.yaml"
|
|
||||||
test = ExplainingOutline(explaining_cfg_path=PATH)
|
|
||||||
|
|
8
main.py
8
main.py
|
@ -0,0 +1,8 @@
|
||||||
|
#!/usr/bin/env python
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
#
|
||||||
|
|
||||||
|
import os
|
||||||
|
from explaining_framework.config.explaining_config import explaining_cfg
|
||||||
|
from explaining_framework.utils.explaining.cmd_args import parse_args
|
||||||
|
from explaining_framework.utils.explaining.outline import parse_args
|
Loading…
Add table
Reference in a new issue