Adding new features
parent e2d47af072
commit a00e73d4f0
6 changed files with 103 additions and 24 deletions
@@ -119,15 +119,20 @@ def set_cfg(explaining_cfg):
    explaining_cfg.threshold_config.relu_and_normalize = True

    # Select device: 'cpu', 'cuda', 'auto'
    explaining_cfg.accelerator = "auto"

    # Which objective metrics to compute: either all, or one in particular if implemented
    explaining_cfg.metrics = CN()
    explaining_cfg.metrics.type = "all"
    explaining_cfg.metrics.name = "all"

    # Whether or not to recompute metrics if they already exist
    explaining_cfg.metrics.force = False

    explaining_cfg.attack = CN()
    explaining_cfg.attack.name = 'all'

    explaining_cfg.accelerator = "auto"


def assert_cfg(explaining_cfg):
    r"""Checks config values, do necessary post processing to the configs"""
@@ -3,6 +3,8 @@ from abc import ABC, abstractmethod

import torch
from torch_geometric.explain.explanation import Explanation

from explaining_framework.utils.io import write_json


class Metric(ABC):
    def __init__(self, name: str, model: torch.nn.Module = None, **kwargs):
@@ -46,3 +48,12 @@ class Metric(ABC):
        self.model.train(training)

        return out

    def save_config(self, path) -> None:
        config = {k: getattr(self, k) for k in dir(self)}
        config = {
            k: v
            for k, v in config.items()
            if isinstance(v, (int, float, str, bool)) or v is None
        }
        write_json(config, path)
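save_config dumps only the scalar attributes of a metric: dir(self) also yields methods and dunders, and the second comprehension keeps just int/float/str/bool/None values so the result is JSON-serializable for write_json. A standalone sketch of the same filtering on a hypothetical toy class (not part of the commit):

class _ToyMetric:
    def __init__(self):
        self.name = "fidelity_plus"  # kept: str
        self.threshold = 0.5         # kept: float
        self.model = object()        # dropped: not a plain scalar

toy = _ToyMetric()
config = {k: getattr(toy, k) for k in dir(toy)}
config = {
    k: v
    for k, v in config.items()
    if isinstance(v, (int, float, str, bool)) or v is None
}
# Bound methods and most dunders are filtered out; note that string/None
# dunders such as __module__ and __doc__ still pass this test.
print(config["name"], config["threshold"])  # fidelity_plus 0.5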
@@ -1,11 +1,12 @@
import torch
import torch.nn.functional as F
from explaining_framework.metric.base import Metric
from torch.nn import KLDivLoss, Softmax
from torch_geometric.explain.explanation import Explanation
from torch_geometric.graphgym.config import cfg

print("NUM CLASSES cfg dataset")
NUM_CLASS = 5
from explaining_framework.metric.base import Metric

NUM_CLASS = cfg.share.dim_out


def softmax(data):
@@ -26,6 +27,8 @@ class Fidelity(Metric):
            "fidelity_plus_prob",
            "fidelity_minus_prob",
            "infidelity_KL",
            "characterization",
            "characterization_prob",
        ]

        self.exp_sub = None
@@ -99,9 +102,42 @@ class Fidelity(Metric):
        self._score_check()
        prob_initial = softmax(self.s_initial_data)
        prob_exp = F.log_softmax(self.s_exp_sub, dim=1)
        print(prob_initial, prob_exp)
        return (1 - torch.exp(-kl(prob_exp, prob_initial))).item()

    def _characterization_prob(
        self,
        exp: Explanation,
        pos_weight: float = 0.5,
        neg_weight: float = 0.5,
    ) -> Tensor:
        if (pos_weight + neg_weight) != 1.0:
            raise ValueError(
                f"The weights need to sum up to 1 "
                f"(got {pos_weight} and {neg_weight})"
            )
        pos_fidelity = self._fidelity_plus_prob(exp)
        neg_fidelity = self._fidelity_minus_prob(exp)

        denom = (pos_weight / pos_fidelity) + (neg_weight / (1.0 - neg_fidelity))
        return 1.0 / denom

    def _characterization(
        self,
        exp: Explanation,
        pos_weight: float = 0.5,
        neg_weight: float = 0.5,
    ) -> Tensor:
        if (pos_weight + neg_weight) != 1.0:
            raise ValueError(
                f"The weights need to sum up to 1 "
                f"(got {pos_weight} and {neg_weight})"
            )
        pos_fidelity = self._fidelity_plus(exp)
        neg_fidelity = self._fidelity_minus(exp)

        denom = (pos_weight / pos_fidelity) + (neg_weight / (1.0 - neg_fidelity))
        return 1.0 / denom

    def score(self, exp):
        self.exp_sub = exp.get_explanation_subgraph()
        self.exp_sub_c = exp.get_complement_subgraph()
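Two notes on the new scores, with a hedged numeric sketch (all values below are made up). The KL-based infidelity follows the torch.nn.KLDivLoss convention of log-probabilities as input and probabilities as target, then squashes the divergence into [0, 1) via 1 - exp(-KL). The characterization score is the weighted harmonic-style combination of fidelity+ and (1 - fidelity-), matching the formula used by torch_geometric's characterization_score:

import torch
from torch.nn import KLDivLoss

# Characterization: high only when fidelity+ is high AND fidelity- is low.
pos_fidelity, neg_fidelity = 0.9, 0.2          # hypothetical scores
pos_weight, neg_weight = 0.5, 0.5
denom = (pos_weight / pos_fidelity) + (neg_weight / (1.0 - neg_fidelity))
print(1.0 / denom)  # ~0.847

# Infidelity_KL: KLDivLoss(input=log-probs, target=probs), mapped into [0, 1).
kl = KLDivLoss(reduction="batchmean")
prob_initial = torch.tensor([[0.7, 0.2, 0.1]])    # full-input prediction (probs)
prob_exp = torch.tensor([[0.5, 0.3, 0.2]]).log()  # masked prediction (log-probs)
print((1 - torch.exp(-kl(prob_exp, prob_initial))).item())  # ~0.08: near-identical outputs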
@@ -125,6 +161,10 @@ class Fidelity(Metric):
                self.metric = lambda exp: self._fidelity_minus_prob(exp)
            if name == "infidelity_KL":
                self.metric = lambda exp: self._infidelity_KL(exp)
            if name == "characterization":
                self.metric = lambda exp: self._characterization(exp)
            if name == "characterization_prob":
                self.metric = lambda exp: self._characterization_prob(exp)
        else:
            raise ValueError(f"{name} is not supported")
        return self.metric
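A hypothetical usage sketch of the dispatch above (the constructor call mirrors how the outline builds its metric list later in this diff; explanation stands for a torch_geometric Explanation returned by an explainer):

fid = Fidelity("characterization_prob")  # the name picks the bound _characterization_prob
value = fid.score(explanation)           # score() first splits the explanation into subgraph / complement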
@@ -18,7 +18,7 @@ def compute_gradient(model, inp, target, loss):
        return torch.autograd.grad(err, inp.x)[0]


class FGSM(object):
class FGSM(Metric):
    def __init__(
        self,
        model: torch.nn.Module,
@@ -26,6 +26,7 @@ class FGSM(object):
        lower_bound: float = float("-inf"),
        upper_bound: float = float("inf"),
    ):
        super().__init__(name=name, model=model)
        self.model = model
        self.loss = loss
        self.lower_bound = lower_bound
@@ -51,7 +52,7 @@ class FGSM(object):
        return input_


class PGD(object):
class PGD(Metric):
    def __init__(
        self,
        model: torch.nn.Module,
@@ -59,6 +60,7 @@ class PGD(object):
        lower_bound: float = float("-inf"),
        upper_bound: float = float("inf"),
    ):
        super().__init__(name=name, model=model)
        self.model = model
        self.loss = loss
        self.lower_bound = lower_bound
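FGSM and PGD now inherit from Metric, so super().__init__(name=name, model=model) gives them the same name/model bookkeeping as the other metrics. For reference, the perturbation they implement is the usual gradient-sign step; a generic, hedged sketch on node features (the textbook update, not necessarily the exact code in this file, whose body is not shown in the diff):

import torch

def fgsm_step(x: torch.Tensor, grad: torch.Tensor, epsilon: float,
              lower_bound: float = float("-inf"),
              upper_bound: float = float("inf")) -> torch.Tensor:
    # Move each feature by epsilon in the direction that increases the loss,
    # then clamp back into the allowed range.
    x_adv = x + epsilon * grad.sign()
    return x_adv.clamp(lower_bound, upper_bound)

# grad would come from compute_gradient(model, inp, target, loss) shown above;
# PGD repeats such steps with a projection back onto the epsilon-ball.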
@@ -92,6 +92,8 @@ class ExplainingOutline(object):
        self.model = None
        self.dataset = None
        self.model_info = None
        self.metrics = None
        self.attacks = None

        self.load_explaining_cfg()
        self.load_model_info()
@@ -100,6 +102,8 @@ class ExplainingOutline(object):
        self.load_model()
        self.load_explainer_cfg()
        self.load_explainer()
        self.load_metric()
        self.load_attack()

    def load_model_info(self):
        info = LoadModelInfo(
@@ -203,27 +207,36 @@ class ExplainingOutline(object):
        if self.explaining_cfg is None:
            self.load_explaining_cfg()

        if self.explaining_cfg.metrics.type == "all":
            if self.explaining_cfg.dataset.name == "BASHAPES":
                all_acc_metrics = [Accuracy(name) for name in all_accuracy]
        name_ = self.explaining_cfg.metrics.type

        if name_ == "all":
            all_fid_metrics = [Fidelity(name) for name in all_fidelity]
            all_spa_metrics = [Sparsity(name) for name in all_sparsity]
            self.metrics = all_acc_metrics + all_fid_metrics

            if self.explaining_cfg.dataset.name == "BASHAPES":
                all_acc_metrics = [Accuracy(name) for name in all_accuracy]
                self.metrics = self.metrics + all_acc_metrics
        elif name_ in all_fidelity:
            self.metrics = [Fidelity(name_)]
        elif name_ in all_sparsity:
            self.metrics = [Sparsity(name_)]
        elif name_ in all_accuracy:
            if self.explaining_cfg.dataset.name == "BASHAPES":
                self.metrics = [Accuracy(name_)]
            else:
                raise ValueError(
                    f"The metric {name_} is not supported for dataset {self.explaining_cfg.dataset.name} yet, it requires groundtruth explanation"
                )

    def load_attack(self):
        if self.cfg is None:
            self.load_cfg()
        if self.explaining_cfg is None:
            self.load_explaining_cfg()
        name_ = self.explaining_cfg.attack.name
        if name_ == "all":
            all_rob_metrics = [Attack(name) for name in all_robust]
            self.attacks = all_rob_metrics
        if name_ in all_robust:
            self.attacks = [Attack(name_)]


class FileManager(object):
    def __init__(self):
        pass

    def save(obj: Any, path: str) -> None:
        pass


PATH = "config_exp.yaml"
test = ExplainingOutline(explaining_cfg_path=PATH)
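With load_metric() and load_attack() wired into the constructor chain, building the outline now also populates self.metrics and self.attacks; the module-level PATH / test lines removed here move into the new main.py below. A hedged usage sketch (attribute names taken from this diff, everything else assumed):

outline = ExplainingOutline(explaining_cfg_path="config_exp.yaml")
for metric in outline.metrics:   # Fidelity/Sparsity objects, plus Accuracy on BASHAPES
    print(metric.name)
for attack in outline.attacks:   # Attack objects selected by explaining_cfg.attack.name
    print(attack.name)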
main.py (new file, 8 additions)
@@ -0,0 +1,8 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#

import os
from explaining_framework.config.explaining_config import explaining_cfg
from explaining_framework.utils.explaining.cmd_args import parse_args
from explaining_framework.utils.explaining.outline import parse_args
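main.py currently only collects imports (the last line presumably means to import ExplainingOutline from the outline module rather than a second parse_args). A hedged sketch of how the entry point might continue, reusing the config path pattern removed from outline.py above (the flag name is an assumption):

# Hypothetical continuation, not part of the commit:
from explaining_framework.utils.explaining.outline import ExplainingOutline

if __name__ == "__main__":
    args = parse_args()
    # "explaining_cfg_file" is a hypothetical attribute; cmd_args defines the real flags.
    outline = ExplainingOutline(explaining_cfg_path=args.explaining_cfg_file)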