From 074ff25c839c80bcf9ddbcd36fe9bebfa03ea021 Mon Sep 17 00:00:00 2001
From: araison
Date: Fri, 30 Dec 2022 19:34:41 +0100
Subject: [PATCH] Adding new features and first draft of main.py

---
 explaining_framework/metric/base.py |   9 +-
 .../utils/explaining/load_ckpt.py   |  32 ++++-
 .../utils/explaining/outline.py     |  20 ++-
 explaining_framework/utils/io.py    |  25 +++-
 main.py                             | 134 +++++++++++++++++-
 5 files changed, 203 insertions(+), 17 deletions(-)

diff --git a/explaining_framework/metric/base.py b/explaining_framework/metric/base.py
index aab33ae..5569ec6 100644
--- a/explaining_framework/metric/base.py
+++ b/explaining_framework/metric/base.py
@@ -49,11 +49,4 @@ class Metric(ABC):
 
         return out
 
-    def save_config(self, path) -> None:
-        config = {k: getattr(self, k) for k in dir(self)}
-        config = {
-            k: v
-            for k, v in config.items()
-            if isinstance(v, (int, float, str, bool)) or v is None
-        }
-        write_json(config, path)
+
diff --git a/explaining_framework/utils/explaining/load_ckpt.py b/explaining_framework/utils/explaining/load_ckpt.py
index 0cc8640..288246d 100644
--- a/explaining_framework/utils/explaining/load_ckpt.py
+++ b/explaining_framework/utils/explaining/load_ckpt.py
@@ -7,11 +7,12 @@ import logging
 import os
 
 import torch
-from explaining_framework.utils.io import read_yaml
 from torch_geometric.graphgym.model_builder import create_model
 from torch_geometric.graphgym.train import GraphGymDataModule
 from torch_geometric.graphgym.utils.io import json_to_dict_list
 
+from explaining_framework.utils.io import read_yaml
+
 MODEL_STATE = "model_state"
 OPTIMIZER_STATE = "optimizer_state"
 SCHEDULER_STATE = "scheduler_state"
@@ -44,14 +45,15 @@ class LoadModelInfo(object):
 
     def list_stats(self, path) -> list:
         info = []
-        for path in glob.glob(
+        for path_ in glob.glob(
             os.path.join(path, "[0-9]", self.wrt_metric, "stats.json")
         ):
-            stats = json_to_dict_list(path)
+            stats = json_to_dict_list(path_)
             for stat in stats:
-                xp_dir_path = os.path.dirname(os.path.dirname(os.path.dirname(path)))
+                xp_dir_path = os.path.dirname(os.path.dirname(os.path.dirname(path_)))
+                seed = int(os.path.basename(os.path.dirname(os.path.dirname(path_))))
                 ckpt_dir_path = os.path.join(
-                    os.path.dirname(os.path.dirname(path)), "ckpt"
+                    os.path.dirname(os.path.dirname(path_)), "ckpt"
                 )
                 cfg_path = os.path.join(xp_dir_path, "config.yaml")
                 epoch = stat["epoch"]
@@ -68,12 +70,16 @@ class LoadModelInfo(object):
                             epoch=epoch, ckpt_dir_path=ckpt_dir_path
                         ),
                         "cfg_path": cfg_path,
+                        "seed": seed,
                         "epoch": epoch,
                         "accuracy": accuracy,
                         "loss": loss,
                         "lr": lr,
                         "params": params,
                         "time_iter": time_iter,
+                        "which": self.which
+                        if self.which in ["best", "worst"]
+                        else None,
                     }
                 )
         return info
@@ -112,6 +118,22 @@ class LoadModelInfo(object):
             self.info = [item for item in stats if item["ckpt_path"] == self.which][0]
         return self.info
 
+    def get_model_signature(self):
+        if self.info is None:
+            self.set_info()
+
+        model_name = os.path.basename(self.info["xp_dir_name"])
+        model_seed = self.info["seed"]
+        epoch = os.path.basename(self.info["ckpt_path"])
+        model_signature = "-".join(
+            [
+                f"{name}={val}"
+                for name, val in zip(["name", "seed"], [model_name, model_seed])
+            ]
+            + [epoch]
+        )
+        return model_signature
+
     def get_ckpt_path(self, epoch: int, ckpt_dir_path: str):
         paths = os.path.join(ckpt_dir_path, "*.ckpt")
         ckpts = []
diff --git a/explaining_framework/utils/explaining/outline.py b/explaining_framework/utils/explaining/outline.py
index 97787a8..d7ff67a 100644
--- a/explaining_framework/utils/explaining/outline.py
+++ b/explaining_framework/utils/explaining/outline.py
@@ -4,6 +4,7 @@ from typing import Any
 from eixgnn.eixgnn import EiXGNN
 from scgnn.scgnn import SCGNN
 from torch_geometric.data import Batch, Data
+from torch_geometric.loader import DataLoader
 from torch_geometric.explain import Explainer
 from torch_geometric.graphgym.config import cfg
 from torch_geometric.graphgym.loader import create_dataset
@@ -62,6 +63,8 @@ all_fidelity = [
     "fidelity_plus_prob",
     "fidelity_minus_prob",
     "infidelity_KL",
+    "characterization",
+    "characterization_prob",
 ]
 all_accuracy = [
     "precision_score",
@@ -94,6 +97,7 @@ class ExplainingOutline(object):
         self.model_info = None
         self.metrics = None
         self.attacks = None
+        self.model_signature = None
 
         self.load_explaining_cfg()
         self.load_model_info()
@@ -112,6 +116,7 @@ class ExplainingOutline(object):
             which=self.explaining_cfg.model.ckpt,
         )
         self.model_info = info.set_info()
+        self.model_signature = info.get_model_signature()
 
     def load_cfg(self):
         cfg.set_new_allowed(True)
@@ -166,6 +171,7 @@ class ExplainingOutline(object):
         if isinstance(self.explaining_cfg.dataset.specific_items, int):
             ind = self.explaining_cfg.dataset.specific_items
             self.dataset = self.dataset[ind : ind + 1]
+        self.dataset = DataLoader(dataset=self.dataset, shuffle=False, batch_size=1)
 
     def load_explainer(self):
         self.load_explainer_cfg()
@@ -199,6 +205,10 @@ class ExplainingOutline(object):
                 interest_map_norm=self.explainer_cfg.interest_map_norm,
                 score_map_norm=self.explainer_cfg.score_map_norm,
             )
+        elif name is None:
+            explaining_algorithm = None
+        else:
+            raise ValueError(f"{name} is an Explainer that is not supported yet")
         self.explaining_algorithm = explaining_algorithm
 
     def load_metric(self):
@@ -228,6 +238,10 @@ class ExplainingOutline(object):
                 raise ValueError(
                     f"The metric {name} is not supported for dataset {self.explaining_cfg.dataset.name} yet, it requires groundtruth explanation"
                 )
+        elif name_ is None:
+            self.metrics = []
+        else:
+            raise ValueError(f"{name_} Metric is not supported yet")
 
     def load_attack(self):
         if self.cfg is None:
@@ -238,5 +252,9 @@ class ExplainingOutline(object):
         if name_ == "all":
             all_rob_metrics = [Attack(name) for name in all_robust]
             self.attacks = all_rob_metrics
-        if name_ in all_robust:
+        elif name_ in all_robust:
             self.attacks = [Attack(name_)]
+        elif name_ is None:
+            self.attacks = []
+        else:
+            raise ValueError(f"{name_} is an Attack method that is not supported yet")
diff --git a/explaining_framework/utils/io.py b/explaining_framework/utils/io.py
index b74c2e5..83f0960 100644
--- a/explaining_framework/utils/io.py
+++ b/explaining_framework/utils/io.py
@@ -1,5 +1,6 @@
 import json
 import os
+
 import yaml
 
 
@@ -24,6 +25,26 @@ def write_yaml(data: dict, path: str) -> None:
     with open(path, "w") as f:
         data = yaml.dump(data, f)
 
 
-def is_exists(path:str)-> bool:
-    return os.path.exists(path)
+def is_exists(path: str) -> bool:
+    return os.path.exists(path)
+
+
+def get_obj_config(obj):
+    config = {k: getattr(obj, k) for k in dir(obj)}
+    config = {
+        k: v
+        for k, v in config.items()
+        if isinstance(v, (int, float, str, bool)) or v is None
+    }
+    return config
+
+
+def save_obj_config(obj, path) -> None:
+    config = get_obj_config(obj)
+    write_json(config, path)
+
+
+def obj_config_to_str(obj) -> str:
+    config = get_obj_config(obj)
+    return "-".join([f"{k}={v}" for k, v in config.items()])
diff --git a/main.py b/main.py
index 38ad4e8..834d8ac 100644
--- a/main.py
+++ b/main.py
@@ -3,6 +3,138 @@
 #
 
 import os
+import copy
+import time
+
+from torch_geometric import seed_everything
+from torch_geometric.data.makedirs import makedirs
+from torch_geometric.explain import Explainer
+from torch_geometric.explain.config import ThresholdConfig
+from torch_geometric.graphgym.config import cfg
+from torch_geometric.graphgym.utils.device import auto_select_device
+
 from explaining_framework.config.explaining_config import explaining_cfg
 from explaining_framework.utils.explaining.cmd_args import parse_args
-from explaining_framework.utils.explaining.outline import parse_args
+from explaining_framework.utils.explaining.outline import ExplainingOutline
+from explaining_framework.utils.io import (is_exists, obj_config_to_str,
+                                           read_json, write_json, write_yaml)
+from explaining_framework.utils.explanation.adjust import Adjust
+
+# inference, time, force,
+
+
+def get_pred(explanation, force=False):
+    # NOTE: relies on the module-level `explainer` created in the __main__ block.
+    dict_ = explanation.to_dict()
+    if dict_.get("pred") is None or dict_.get("pred_masked") is None or force:
+        pred = explainer.get_prediction(explanation)
+        pred_masked = explainer.get_masked_prediction(
+            x=explanation.x,
+            edge_index=explanation.edge_index,
+            node_mask=explanation.node_mask,
+            edge_mask=explanation.edge_mask,
+        )
+        explanation.__setattr__("pred", pred)
+        explanation.__setattr__("pred_masked", pred_masked)
+        return explanation
+    else:
+        return explanation
+
+
+if __name__ == "__main__":
+    args = parse_args()
+    outline = ExplainingOutline(args.explaining_cfg_file)
+    auto_select_device()
+
+    # Load components
+    dataset = outline.dataset
+    model = outline.model.to(cfg.accelerator)
+    model_info = outline.model_info
+    metrics = outline.metrics
+    explaining_algorithm = outline.explaining_algorithm
+    attacks = outline.attacks
+    explainer_cfg = outline.explainer_cfg
+    model_signature = outline.model_signature
+
+    # Set seed
+    seed_everything(explaining_cfg.seed)
+
+    # Global path
+    global_path = os.path.join(explaining_cfg.out_dir, model_signature)
+    makedirs(global_path)
+    write_yaml(cfg, os.path.join(global_path, "config.yaml"))
+    write_json(model_info, os.path.join(global_path, "info.json"))
+
+    global_path = os.path.join(
+        global_path,
+        explaining_cfg.explainer.name + "_" + obj_config_to_str(explaining_algorithm),
+    )
+    makedirs(global_path)
+    write_yaml(explaining_cfg, os.path.join(global_path, explaining_cfg.cfg_dest))
+    write_yaml(explainer_cfg, os.path.join(global_path, "explainer_cfg.yaml"))
+
+    global_path = os.path.join(global_path, obj_config_to_str(explaining_algorithm))
+    makedirs(global_path)
+    explainer = Explainer(
+        model=model,
+        algorithm=explaining_algorithm,
+        explainer_config=dict(
+            explanation_type=explaining_cfg.explanation_type,
+            node_mask_type="object",
+            edge_mask_type="object",
+        ),
+        model_config=dict(
+            mode="regression",
+            task_level=cfg.dataset.task,
+            return_type=explaining_cfg.model_config.return_type,
+        ),
+    )
+    # Save explaining configuration
+    for index, item in enumerate(dataset):
+        item = item.to(cfg.accelerator)
+        save_raw_path = os.path.join(global_path, "raw")
+        makedirs(save_raw_path)
+        explanation_path = os.path.join(save_raw_path, f"{index}.json")
+
+        if is_exists(explanation_path):
+            if explaining_cfg.explainer.force:
+                explanation = explainer(
+                    x=item.x,
+                    edge_index=item.edge_index,
+                    index=item.y,
+                    target=item.y,
+                )
+            else:
+                explanation = load_explanation(explanation_path)
+        else:
+            explanation = explainer(
+                x=item.x,
+                edge_index=item.edge_index,
+                index=item.y,
+                target=item.y,
+            )
+        explanation = get_pred(explanation, force=False)
+        save_explanation(explanation, explanation_path)
+        for apply_relu in [True, False]:
+            for apply_absolute in [True, False]:
+                adjust = Adjust(apply_relu=apply_relu, apply_absolute=apply_absolute)
+                save_raw_path = os.path.join(
+                    global_path, f"adjust-{obj_config_to_str(adjust)}"
+                )
+                makedirs(save_raw_path)
+                # Derive the adjusted variant of the raw explanation for this
+                # (relu, absolute) combination and cache its predictions.
+                exp_adjust = adjust.forward(explanation)
+                explanation_path = os.path.join(save_raw_path, f"{index}.json")
+                exp_adjust = get_pred(exp_adjust, force=True)
+                save_explanation(exp_adjust, explanation_path)
+
+                # Threshold the adjusted explanation and evaluate metrics on it.
+                for threshold_approach in ["hard", "topk", "topk_hard"]:
+                    for threshold_value in explaining_cfg.threshold_config.value:
+
+                        masking_path = os.path.join(
+                            save_raw_path,
+                            f"threshold={threshold_approach}-value={threshold_value}",
+                        )
+                        exp_threshold_path = os.path.join(masking_path, f"{index}.json")
+
+                        if is_exists(exp_threshold_path):
+                            exp_threshold = load_explanation(exp_threshold_path)
+                        else:
+                            threshold_conf = {
+                                "threshold_type": threshold_approach,
+                                "value": threshold_value,
+                            }
+                            explainer.threshold_config = ThresholdConfig.cast(
+                                threshold_conf
+                            )
+
+                            expl = copy.copy(exp_adjust)
+                            exp_threshold = explainer._post_process(expl)
+                            exp_threshold = get_pred(exp_threshold, force=True)
+
+                            save_explanation(exp_threshold, exp_threshold_path)
+                        for metric in metrics:
+                            metric_path = os.path.join(
+                                masking_path, obj_config_to_str(metric)
+                            )
+                            if is_exists(os.path.join(metric_path, f"{index}.json")):
+                                continue
+                            else:
+                                out = metric.forward(exp_threshold)
+                                write_json(
+                                    {f"{metric.name}": out},
+                                    os.path.join(metric_path, f"{index}.json"),
+                                )
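
Note: the snippets below are illustrations only and are not part of the diff above.

First, the shape of the string assembled by LoadModelInfo.get_model_signature(); the experiment name, seed and checkpoint file are hypothetical placeholders.

    # Mirrors the join performed in get_model_signature(); the values are made up.
    model_name = "example_xp"  # basename of the experiment directory (xp_dir_name)
    model_seed = 0             # seed folder parsed by list_stats()
    epoch = "42.ckpt"          # basename of the selected checkpoint file
    signature = "-".join(
        [f"{name}={val}" for name, val in zip(["name", "seed"], [model_name, model_seed])]
        + [epoch]
    )
    print(signature)  # -> name=example_xp-seed=0-42.ckpt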
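
Second, a minimal usage sketch of the helpers added to explaining_framework/utils/io.py, assuming the package is importable; DummyExplainer and its attribute values are made-up stand-ins for a configured explainer or metric object.

    from explaining_framework.utils.io import (get_obj_config, obj_config_to_str,
                                               save_obj_config)


    class DummyExplainer:
        def __init__(self):
            self.interest_map_norm = True
            self.score_map_norm = False
            self.depth = 2


    algo = DummyExplainer()
    # get_obj_config keeps attributes whose values are int/float/str/bool/None
    # (a few dunders such as __doc__ and __module__ also pass this filter).
    config = get_obj_config(algo)
    # obj_config_to_str joins the kept items as "key=value" pairs separated by "-".
    print(obj_config_to_str(algo))
    # save_obj_config writes the same dictionary to JSON via write_json.
    save_obj_config(algo, "explainer_config.json")

main.py uses obj_config_to_str to build human-readable directory names for each explainer, adjustment and metric configuration, so every combination gets its own folder.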
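
Finally, a rough sketch of the output tree main.py writes for one model; every name is derived from the configurations at run time.

    <explaining_cfg.out_dir>/<model signature>/
        config.yaml
        info.json
        <explainer name>_<explainer config string>/
            <explaining_cfg.cfg_dest>
            explainer_cfg.yaml
            <explainer config string>/
                raw/<index>.json
                adjust-<adjust config string>/<index>.json
                adjust-<adjust config string>/threshold=<approach>-value=<v>/<index>.json
                adjust-<adjust config string>/threshold=<approach>-value=<v>/<metric config string>/<index>.json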