def load_eixgnn_cfg(eixgnn_cfg, args):
    r"""Populate ``eixgnn_cfg`` from a YAML file plus command-line overrides.

    Args:
        eixgnn_cfg (CfgNode): configuration node, updated in place.
        args: parsed command-line namespace exposing ``eixgnn_cfg_file``
            (path to the YAML config) and ``opts`` (list of KEY VALUE
            override pairs).
    """
    # File values first, CLI overrides second, then validate the result.
    eixgnn_cfg.merge_from_file(args.eixgnn_cfg_file)
    eixgnn_cfg.merge_from_list(args.opts)
    assert_eixgnn_cfg(eixgnn_cfg)
b/explaining_framework/config/explainer_config/scgnn_config.py index adbb741..1766761 100644 --- a/explaining_framework/config/explainer_config/scgnn_config.py +++ b/explaining_framework/config/explainer_config/scgnn_config.py @@ -36,27 +36,25 @@ def set_scgnn_cfg(scgnn_cfg): if scgnn_cfg is None: return scgnn_cfg - scgnn_cfg.depth = 'all' + scgnn_cfg.depth = "all" scgnn_cfg.interest_map_norm = True scgnn_cfg.score_map_norm = True - - -def assert_scgnn_cfg(scgnn_cfg): +def assert_cfg(scgnn_cfg): r"""Checks config values, do necessary post processing to the configs - TODO + TODO - """ - if scgnn_cfg. not in ["node", "edge", "graph", "link_pred"]: - raise ValueError( - "Task {} not supported, must be one of node, " - "edge, graph, link_pred".format(scgnn_cfg.dataset.task) - ) - scgnn_cfg.run_dir = scgnn_cfg.out_dir + """ + # if scgnn_cfg. not in ["node", "edge", "graph", "link_pred"]: + # raise ValueError( + # "Task {} not supported, must be one of node, " + # "edge, graph, link_pred".format(scgnn_cfg.dataset.task) + # ) + # scgnn_cfg.run_dir = scgnn_cfg.out_dir -def dump_scgnn_cfg(scgnn_cfg,path): +def dump_cfg(scgnn_cfg, path): r""" TODO Dumps the config to the output directory specified in @@ -65,14 +63,12 @@ def dump_scgnn_cfg(scgnn_cfg,path): scgnn_cfg (CfgNode): Configuration node """ makedirs(scgnn_cfg.out_dir) - scgnn_cfg_file = os.path.join( - scgnn_cfg.out_dir, scgnn_cfg.scgnn_cfg_dest - ) + scgnn_cfg_file = os.path.join(scgnn_cfg.out_dir, scgnn_cfg.scgnn_cfg_dest) with open(scgnn_cfg_file, "w") as f: scgnn_cfg.dump(stream=f) -def load_scgnn_cfg(scgnn_cfg, args): +def load_cfg(scgnn_cfg, args): r""" Load configurations from file system and command line Args: @@ -116,8 +112,6 @@ def set_out_dir(out_dir, fname): # Make output directory if scgnn_cfg.train.auto_resume: os.makedirs(scgnn_cfg.out_dir, exist_ok=True) - else: - makedirs_rm_exist(scgnn_cfg.out_dir) def set_run_dir(out_dir): @@ -131,8 +125,6 @@ def set_run_dir(out_dir): # Make output directory if 
scgnn_cfg.train.auto_resume: os.makedirs(scgnn_cfg.run_dir, exist_ok=True) - else: - makedirs_rm_exist(scgnn_cfg.run_dir) set_scgnn_cfg(scgnn_cfg) diff --git a/explaining_framework/config/explaining_config.py b/explaining_framework/config/explaining_config.py index b14db35..ba2fc73 100644 --- a/explaining_framework/config/explaining_config.py +++ b/explaining_framework/config/explaining_config.py @@ -122,7 +122,8 @@ def set_cfg(explaining_cfg): explaining_cfg.accelerator = "auto" # which objectives metrics to computes, either all or one in particular if implemented - explaining_cfg.metrics = "all" + explaining_cfg.metrics = CN() + explaining_cfg.metrics.type = "all" # Whether or not recomputing metrics if they already exist explaining_cfg.metrics.force = False diff --git a/explaining_framework/explainers/wrappers/from_graphxai.py b/explaining_framework/explainers/wrappers/from_graphxai.py index 2ac92ed..04ac2fa 100644 --- a/explaining_framework/explainers/wrappers/from_graphxai.py +++ b/explaining_framework/explainers/wrappers/from_graphxai.py @@ -86,7 +86,7 @@ class GraphXAIWrapper(ExplainerAlgorithm): if criterion == "mse": loss = MSELoss() return loss - elif criterion == "cross-entropy": + elif criterion == "cross_entropy": loss = CrossEntropyLoss() return loss else: diff --git a/explaining_framework/utils/explaining/cmd_args.py b/explaining_framework/utils/explaining/cmd_args.py index 23840fc..388d231 100644 --- a/explaining_framework/utils/explaining/cmd_args.py +++ b/explaining_framework/utils/explaining/cmd_args.py @@ -5,13 +5,6 @@ def parse_args() -> argparse.Namespace: r"""Parses the command line arguments.""" parser = argparse.ArgumentParser(description="ExplainingFramework") - parser.add_argument( - "--cfg", - dest="cfg_file", - type=str, - required=True, - help="The configuration file path.", - ) parser.add_argument( "--explaining_cfg", dest="explaining_cfg_file", diff --git a/explaining_framework/utils/explaining/explaining_exp.py 
import copy

from eixgnn.eixgnn import EiXGNN
from explaining_framework.config.explainer_config.eixgnn_config import \
    eixgnn_cfg
from explaining_framework.config.explainer_config.scgnn_config import scgnn_cfg
from explaining_framework.config.explaining_config import explaining_cfg
from explaining_framework.explainers.wrappers.from_captum import CaptumWrapper
from explaining_framework.explainers.wrappers.from_graphxai import \
    GraphXAIWrapper
from explaining_framework.metric.accuracy import Accuracy
from explaining_framework.metric.fidelity import Fidelity
from explaining_framework.metric.robust import Attack
from explaining_framework.metric.sparsity import Sparsity
from explaining_framework.utils.explaining.load_ckpt import (LoadModelInfo,
                                                             _load_ckpt)
from scgnn.scgnn import SCGNN
from torch_geometric.data import Batch, Data
from torch_geometric.explain import Explainer
from torch_geometric.graphgym.config import cfg
from torch_geometric.graphgym.loader import create_dataset
# BUG FIX: `cfg` was also imported from model_builder, shadowing the one
# from graphgym.config; only `create_model` is needed from there.
from torch_geometric.graphgym.model_builder import create_model
from torch_geometric.graphgym.utils.device import auto_select_device

# Explainer names served by the Captum wrapper.
all__captum = [
    "LRP",
    "DeepLift",
    "DeepLiftShap",
    "FeatureAblation",
    "FeaturePermutation",
    "GradientShap",
    "GuidedBackprop",
    "GuidedGradCam",
    "InputXGradient",
    "IntegratedGradients",
    "Lime",
    "Occlusion",
    "Saliency",
]

# Explainer names served by the GraphXAI wrapper.
all__graphxai = [
    "CAM",
    "GradCAM",
    "GNN_LRP",
    "GradExplainer",
    "GuidedBackPropagation",
    "IntegratedGradients",
    "PGExplainer",
    "PGMExplainer",
    "RandomExplainer",
    "SubgraphX",
    "GraphMASK",
]

# Explainers implemented natively by this framework.
all__own = ["EIXGNN", "SCGNN"]


class ExplainingOutline(object):
    """Assemble everything needed to run one explanation experiment.

    Starting from an explaining-config YAML path, resolves (in order) the
    trained-model checkpoint info, the GraphGym config the model was
    trained with, the dataset, the explainer-specific config and finally
    the explaining algorithm itself.
    """

    def __init__(self, explaining_cfg_path: str):
        self.explaining_cfg_path = explaining_cfg_path
        self.explaining_cfg = None
        self.explainer_cfg_path = None
        self.explainer_cfg = None
        self.explaining_algorithm = None
        self.cfg = None
        self.model = None
        self.dataset = None
        self.model_info = None

        # Order matters: the explaining cfg drives the model lookup, which
        # provides the GraphGym cfg used by dataset/model creation.
        self.load_explaining_cfg()
        self.load_model_info()
        self.load_cfg()
        self.load_dataset()
        self.load_model()
        self.load_explainer_cfg()
        self.load_explainer()

    def load_model_info(self):
        """Resolve ckpt/cfg paths of the trained model to be explained."""
        info = LoadModelInfo(
            dataset_name=self.explaining_cfg.dataset.name,
            model_dir=self.explaining_cfg.model.path,
            which=self.explaining_cfg.model.ckpt,
        )
        self.model_info = info.set_info()

    def load_cfg(self):
        """Load the GraphGym config the selected model was trained with."""
        cfg.set_new_allowed(True)
        cfg.merge_from_file(self.model_info["cfg_path"])
        self.cfg = cfg

    def load_explaining_cfg(self):
        """Load the explaining configuration from ``explaining_cfg_path``."""
        explaining_cfg.set_new_allowed(True)
        explaining_cfg.merge_from_file(self.explaining_cfg_path)
        self.explaining_cfg = explaining_cfg

    def load_explainer_cfg(self):
        """Load the explainer-specific config (EIXGNN/SCGNN only)."""
        if self.explaining_cfg is None:
            # BUG FIX: was `self.explaining_cfg()`, which *called* the
            # CfgNode instead of loading it.
            self.load_explaining_cfg()
        if self.explaining_cfg.explainer.cfg == "default":
            if self.explaining_cfg.explainer.name == "EIXGNN":
                self.explainer_cfg = copy.copy(eixgnn_cfg)
            elif self.explaining_cfg.explainer.name == "SCGNN":
                self.explainer_cfg = copy.copy(scgnn_cfg)
            else:
                # Captum/GraphXAI explainers carry no dedicated config.
                self.explainer_cfg = None
        else:
            if self.explaining_cfg.explainer.name == "EIXGNN":
                eixgnn_cfg.set_new_allowed(True)
                eixgnn_cfg.merge_from_file(self.explaining_cfg.explainer.cfg)
                self.explainer_cfg = eixgnn_cfg
            elif self.explaining_cfg.explainer.name == "SCGNN":
                scgnn_cfg.set_new_allowed(True)
                scgnn_cfg.merge_from_file(self.explaining_cfg.explainer.cfg)
                self.explainer_cfg = scgnn_cfg

    def load_model(self):
        """Build the model and restore its weights from the checkpoint.

        Raises:
            ValueError: if the checkpoint file could not be found.
        """
        if self.cfg is None:
            self.load_cfg()
        auto_select_device()
        self.model = create_model()
        self.model = _load_ckpt(self.model, self.model_info["ckpt_path"])
        if self.model is None:
            raise ValueError("Model ckpt has not been loaded, ckpt file not found")

    def load_dataset(self):
        """Create the dataset, checking it matches the model's dataset."""
        if self.cfg is None:
            self.load_cfg()
        if self.explaining_cfg is None:
            # BUG FIX: was `self.explaining_cfg()` (see load_explainer_cfg).
            self.load_explaining_cfg()
        if self.explaining_cfg.dataset.name != self.cfg.dataset.name:
            raise ValueError(
                f"Expecting that the dataset to perform explanation on is the same as the model has trained on. Get {self.explaining_cfg.dataset.name} for explanation part, and {self.cfg.dataset.name} for the model."
            )
        self.dataset = create_dataset()

    def load_explainer(self):
        """Instantiate the explaining algorithm named in the config.

        Raises:
            ValueError: if the configured explainer name is unknown.
        """
        self.load_explainer_cfg()
        if self.model is None:
            self.load_model()
        if self.dataset is None:
            self.load_dataset()

        name = self.explaining_cfg.explainer.name
        if name in all__captum:
            explaining_algorithm = CaptumWrapper(name)
        elif name in all__graphxai:
            explaining_algorithm = GraphXAIWrapper(
                name,
                in_channels=self.dataset.num_classes,
                criterion=self.cfg.model.loss_fun,
            )
        elif name in all__own:
            if name == "EIXGNN":
                explaining_algorithm = EiXGNN(
                    L=self.explainer_cfg.L,
                    p=self.explainer_cfg.p,
                    importance_sampling_strategy=self.explainer_cfg.importance_sampling_strategy,
                    domain_similarity=self.explainer_cfg.domain_similarity,
                    signal_similarity=self.explainer_cfg.signal_similarity,
                    shap_val_approx=self.explainer_cfg.shapley_value_approx,
                )
            else:  # name == "SCGNN"
                explaining_algorithm = SCGNN(
                    depth=self.explainer_cfg.depth,
                    interest_map_norm=self.explainer_cfg.interest_map_norm,
                    score_map_norm=self.explainer_cfg.score_map_norm,
                )
        else:
            # BUG FIX: an unknown name previously left `explaining_algorithm`
            # unbound and surfaced later as a confusing NameError.
            raise ValueError(f"Unknown explainer name: {name}")
        self.explaining_algorithm = explaining_algorithm


if __name__ == "__main__":
    # BUG FIX: this smoke test previously ran unconditionally at import
    # time (with a hard-coded path and a debug print); it is now guarded.
    PATH = "config_exp.yaml"
    test = ExplainingOutline(explaining_cfg_path=PATH)
    print(test.explaining_algorithm.__dict__)
def _load_ckpt(
    model: torch.nn.Module,
    ckpt_path: str,
) -> torch.nn.Module:
    r"""Restore `model`'s weights from `ckpt_path`.

    Returns the model with weights loaded, or ``None`` when the checkpoint
    file does not exist (callers treat ``None`` as "not found").
    """
    if not os.path.exists(ckpt_path):
        return None
    ckpt = torch.load(ckpt_path)
    # NOTE(review): assumes a Lightning-style checkpoint dict keyed by
    # "state_dict" — confirm against the training pipeline.
    model.load_state_dict(ckpt["state_dict"])
    return model


class LoadModelInfo(object):
    """Locate the checkpoint, config and stats of a trained GraphGym model.

    Experiments are expected to be laid out as
    ``<xp_dir>/<seed>/<wrt_metric>/stats.json`` with checkpoints in
    ``<xp_dir>/<seed>/ckpt/*.ckpt`` and the config at
    ``<xp_dir>/config.yaml``.
    """

    def __init__(
        self, dataset_name: str, model_dir: str, which: str, wrt_metric: str = "val"
    ):
        # Name of the dataset the model must have been trained on.
        self.dataset_name = dataset_name
        # Root directory containing the experiment folders.
        self.model_dir = model_dir
        # "best", "worst", an int index into accuracy-sorted runs, or an
        # explicit *.ckpt path.
        self.which = which
        # Which split's stats.json to rank runs by (e.g. "val").
        self.wrt_metric = wrt_metric
        self.info = None

    def list_stats(self, path) -> list:
        """Collect one info dict per epoch entry found under `path`."""
        info = []
        # BUG FIX: the loop variable used to shadow the `path` parameter.
        for stats_path in glob.glob(
            os.path.join(path, "[0-9]", self.wrt_metric, "stats.json")
        ):
            for stat in json_to_dict_list(stats_path):
                # stats.json lives at <xp>/<seed>/<metric>/stats.json.
                seed_dir = os.path.dirname(os.path.dirname(stats_path))
                xp_dir_path = os.path.dirname(seed_dir)
                ckpt_dir_path = os.path.join(seed_dir, "ckpt")
                epoch = stat["epoch"]
                # BUG FIX: the ckpt path was computed and then recomputed,
                # doubling the filesystem globbing per stat row.
                ckpt_path = self.get_ckpt_path(
                    epoch=epoch, ckpt_dir_path=ckpt_dir_path
                )
                info.append(
                    {
                        "xp_dir_path": xp_dir_path,
                        "ckpt_path": ckpt_path,
                        "cfg_path": os.path.join(xp_dir_path, "config.yaml"),
                        "epoch": epoch,
                        "accuracy": stat["accuracy"],
                        "loss": stat["loss"],
                        "lr": stat["lr"],
                        "params": stat["params"],
                        "time_iter": stat["time_iter"],
                    }
                )
        return info

    def list_xp(self):
        """Return xp dirs whose config.yaml trains on `dataset_name`."""
        paths = []
        # NOTE(review): "**" without recursive=True matches one directory
        # level only — confirm this is the intended search depth.
        for cfg_path in glob.glob(os.path.join(self.model_dir, "**", "config.yaml")):
            file = self.load_cfg(cfg_path)
            if self.dataset_name == file["dataset"]["name"]:
                paths.append(os.path.dirname(cfg_path))
        return paths

    def load_cfg(self, config_path):
        """Parse an experiment's config.yaml into a dict."""
        return read_yaml(config_path)

    def set_info(self):
        """Select and cache the run described by `self.which`.

        Raises:
            FileNotFoundError: if no matching stats/checkpoint is found.
        """
        if self.which in ["best", "worst"] or isinstance(self.which, int):
            infos = []
            for xp_path in self.list_xp():
                infos.extend(self.list_stats(xp_path))
            # BUG FIX: an empty result used to surface as a bare IndexError.
            if not infos:
                raise FileNotFoundError(
                    f"No stats found under {self.model_dir} "
                    f"for dataset {self.dataset_name}"
                )
            infos = sorted(infos, key=lambda item: item["accuracy"])
            if self.which == "best":
                self.info = infos[-1]
            elif self.which == "worst":
                self.info = infos[0]
            else:
                self.info = infos[self.which]
        else:
            # `self.which` is an explicit *.ckpt path; checkpoints live at
            # <xp>/<seed>/ckpt/<file>.ckpt, so the xp dir is three levels up.
            specific_xp_dir_path = os.path.dirname(
                os.path.dirname(os.path.dirname(self.which))
            )
            stats = self.list_stats(specific_xp_dir_path)
            matches = [item for item in stats if item["ckpt_path"] == self.which]
            # BUG FIX: a missing match used to raise a bare IndexError.
            if not matches:
                raise FileNotFoundError(f"No stats entry matches ckpt {self.which}")
            self.info = matches[0]
        return self.info

    def get_ckpt_path(self, epoch: int, ckpt_dir_path: str):
        """Return the *.ckpt whose encoded epoch is closest to `epoch`.

        Filenames are expected to contain an ``epoch=<N>`` token among
        '-'-separated fields (e.g. ``epoch=3-step=10.ckpt``).

        Raises:
            FileNotFoundError: if no parseable checkpoint file is found.
        """
        ckpts = []
        for ckpt_file in glob.glob(os.path.join(ckpt_dir_path, "*.ckpt")):
            for token in os.path.basename(ckpt_file).split("-"):
                if "epoch" in token:
                    # BUG FIX: a name like "epoch=5.ckpt" (no trailing
                    # "-step=..." field) crashed int(); strip the suffix.
                    value = token.split("=")[1].removesuffix(".ckpt")
                    ckpts.append(
                        {"path": ckpt_file, "div": abs(int(value) - epoch)}
                    )
        # BUG FIX: an empty directory used to raise a bare IndexError.
        if not ckpts:
            raise FileNotFoundError(
                f"No *.ckpt file with an 'epoch=' tag found in {ckpt_dir_path}"
            )
        return min(ckpts, key=lambda item: item["div"])["path"]