Fixing bugs and adding new features

This commit is contained in:
araison 2022-12-27 17:54:59 +01:00
parent e1b2a64bd6
commit 7c4f5ca196
8 changed files with 294 additions and 121 deletions

View File

@@ -66,17 +66,18 @@ set_eixgnn_cfg(eixgnn_cfg)
# eixgnn_cfg.dump(stream=f)
#
#
# def load_eixgnn_cfg(eixgnn_cfg, args):
# r"""
# Load configurations from file system and command line
# Args:
# eixgnn_cfg (CfgNode): Configuration node
# args (ArgumentParser): Command argument parser
# """
# eixgnn_cfg.merge_from_file(args.eixgnn_cfg_file)
# eixgnn_cfg.merge_from_list(args.opts)
# assert_eixgnn_cfg(eixgnn_cfg)
#
def load_eixgnn_cfg(eixgnn_cfg, args):
r"""
Load configurations from file system and command line
Args:
eixgnn_cfg (CfgNode): Configuration node
args (ArgumentParser): Command argument parser
"""
eixgnn_cfg.merge_from_file(args.eixgnn_cfg_file)
eixgnn_cfg.merge_from_list(args.opts)
assert_eixgnn_cfg(eixgnn_cfg)
#
# def makedirs_rm_exist(dir):
# if os.path.isdir(dir):
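The previously commented-out load_eixgnn_cfg is re-enabled above. A minimal usage sketch, assuming the function is importable from the same module as eixgnn_cfg; the YAML path and override pair are hypothetical:

import argparse

from explaining_framework.config.explainer_config.eixgnn_config import (
    eixgnn_cfg, load_eixgnn_cfg)  # load_eixgnn_cfg location is an assumption

# merge_from_file reads the YAML file, merge_from_list applies "KEY VALUE" overrides
args = argparse.Namespace(
    eixgnn_cfg_file="configs/eixgnn.yaml",  # hypothetical path
    opts=["L", "30"],                       # hypothetical override pair
)
load_eixgnn_cfg(eixgnn_cfg, args)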

View File

@@ -36,27 +36,25 @@ def set_scgnn_cfg(scgnn_cfg):
if scgnn_cfg is None:
return scgnn_cfg
scgnn_cfg.depth = 'all'
scgnn_cfg.depth = "all"
scgnn_cfg.interest_map_norm = True
scgnn_cfg.score_map_norm = True
def assert_scgnn_cfg(scgnn_cfg):
def assert_cfg(scgnn_cfg):
r"""Checks config values, do necessary post processing to the configs
TODO
TODO
"""
if scgnn_cfg. not in ["node", "edge", "graph", "link_pred"]:
raise ValueError(
"Task {} not supported, must be one of node, "
"edge, graph, link_pred".format(scgnn_cfg.dataset.task)
)
scgnn_cfg.run_dir = scgnn_cfg.out_dir
"""
# if scgnn_cfg. not in ["node", "edge", "graph", "link_pred"]:
# raise ValueError(
# "Task {} not supported, must be one of node, "
# "edge, graph, link_pred".format(scgnn_cfg.dataset.task)
# )
# scgnn_cfg.run_dir = scgnn_cfg.out_dir
def dump_scgnn_cfg(scgnn_cfg,path):
def dump_cfg(scgnn_cfg, path):
r"""
TODO
Dumps the config to the output directory specified in
@@ -65,14 +63,12 @@ def dump_scgnn_cfg(scgnn_cfg,path):
scgnn_cfg (CfgNode): Configuration node
"""
makedirs(scgnn_cfg.out_dir)
scgnn_cfg_file = os.path.join(
scgnn_cfg.out_dir, scgnn_cfg.scgnn_cfg_dest
)
scgnn_cfg_file = os.path.join(scgnn_cfg.out_dir, scgnn_cfg.scgnn_cfg_dest)
with open(scgnn_cfg_file, "w") as f:
scgnn_cfg.dump(stream=f)
def load_scgnn_cfg(scgnn_cfg, args):
def load_cfg(scgnn_cfg, args):
r"""
Load configurations from file system and command line
Args:
@@ -116,8 +112,6 @@ def set_out_dir(out_dir, fname):
# Make output directory
if scgnn_cfg.train.auto_resume:
os.makedirs(scgnn_cfg.out_dir, exist_ok=True)
else:
makedirs_rm_exist(scgnn_cfg.out_dir)
def set_run_dir(out_dir):
@@ -131,8 +125,6 @@ def set_run_dir(out_dir):
# Make output directory
if scgnn_cfg.train.auto_resume:
os.makedirs(scgnn_cfg.run_dir, exist_ok=True)
else:
makedirs_rm_exist(scgnn_cfg.run_dir)
set_scgnn_cfg(scgnn_cfg)
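The scgnn config helpers drop their scgnn_ prefix here (assert_cfg, dump_cfg, load_cfg). A hedged sketch of dump_cfg after the rename; note that the body shown above derives the destination from scgnn_cfg.out_dir and scgnn_cfg.scgnn_cfg_dest, so the path argument is effectively unused:

from explaining_framework.config.explainer_config.scgnn_config import (
    dump_cfg, scgnn_cfg)  # dump_cfg location is an assumption

scgnn_cfg.out_dir = "results/scgnn"       # hypothetical output directory
scgnn_cfg.scgnn_cfg_dest = "config.yaml"  # hypothetical destination file name
dump_cfg(scgnn_cfg, path=scgnn_cfg.out_dir)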

View File

@@ -122,7 +122,8 @@ def set_cfg(explaining_cfg):
explaining_cfg.accelerator = "auto"
# Which objective metrics to compute: either all of them, or one in particular if implemented
explaining_cfg.metrics = "all"
explaining_cfg.metrics = CN()
explaining_cfg.metrics.type = "all"
# Whether or not to recompute metrics if they already exist
explaining_cfg.metrics.force = False
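explaining_cfg.metrics is promoted from a plain string to a config node, grouping the metric selection with the new recompute flag. A hedged override sketch using the yacs merge_from_list API (values are illustrative):

from explaining_framework.config.explaining_config import explaining_cfg

# override both new fields in one call; yacs coerces the string "True" to a bool
explaining_cfg.merge_from_list(["metrics.type", "all", "metrics.force", "True"])
print(explaining_cfg.metrics.type, explaining_cfg.metrics.force)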

View File

@@ -86,7 +86,7 @@ class GraphXAIWrapper(ExplainerAlgorithm):
if criterion == "mse":
loss = MSELoss()
return loss
elif criterion == "cross-entropy":
elif criterion == "cross_entropy":
loss = CrossEntropyLoss()
return loss
else:
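The criterion key switches from "cross-entropy" to "cross_entropy", matching the spelling GraphGym uses for cfg.model.loss_fun, which outline.py (below) forwards directly. A hedged construction sketch; the explainer name and channel count are illustrative:

from explaining_framework.explainers.wrappers.from_graphxai import GraphXAIWrapper

# criterion now takes the underscore spelling, so cfg.model.loss_fun can be passed as-is
explainer = GraphXAIWrapper("GradCAM", in_channels=10, criterion="cross_entropy")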

View File

@@ -5,13 +5,6 @@ def parse_args() -> argparse.Namespace:
r"""Parses the command line arguments."""
parser = argparse.ArgumentParser(description="ExplainingFramework")
parser.add_argument(
"--cfg",
dest="cfg_file",
type=str,
required=True,
help="The configuration file path.",
)
parser.add_argument(
"--explaining_cfg",
dest="explaining_cfg_file",

View File

@@ -1,14 +0,0 @@
class ExplainingOutline(object):
def __init__(self, explaining_cfg: CN, explainer_cfg: CN = None):
self.explaining_cfg = explaining_cfg
self.explainer_cfg = explainer_cfg
def load_cfg(self):
if self

View File

@@ -0,0 +1,174 @@
import copy
from eixgnn.eixgnn import EiXGNN
from explaining_framework.config.explainer_config.eixgnn_config import \
eixgnn_cfg
from explaining_framework.config.explainer_config.scgnn_config import scgnn_cfg
from explaining_framework.config.explaining_config import explaining_cfg
from explaining_framework.explainers.wrappers.from_captum import CaptumWrapper
from explaining_framework.explainers.wrappers.from_graphxai import \
GraphXAIWrapper
from explaining_framework.metric.accuracy import Accuracy
from explaining_framework.metric.fidelity import Fidelity
from explaining_framework.metric.robust import Attack
from explaining_framework.metric.sparsity import Sparsity
from explaining_framework.utils.explaining.load_ckpt import (LoadModelInfo,
_load_ckpt)
from scgnn.scgnn import SCGNN
from torch_geometric.data import Batch, Data
from torch_geometric.explain import Explainer
from torch_geometric.graphgym.config import cfg
from torch_geometric.graphgym.loader import create_dataset
from torch_geometric.graphgym.model_builder import create_model
from torch_geometric.graphgym.utils.device import auto_select_device
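# Explainer names recognised by each backend: Captum wrappers, GraphXAI wrappers,
# or the in-house EiXGNN/SCGNN implementations.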
all__captum = [
"LRP",
"DeepLift",
"DeepLiftShap",
"FeatureAblation",
"FeaturePermutation",
"GradientShap",
"GuidedBackprop",
"GuidedGradCam",
"InputXGradient",
"IntegratedGradients",
"Lime",
"Occlusion",
"Saliency",
]
all__graphxai = [
"CAM",
"GradCAM",
"GNN_LRP",
"GradExplainer",
"GuidedBackPropagation",
"IntegratedGradients",
"PGExplainer",
"PGMExplainer",
"RandomExplainer",
"SubgraphX",
"GraphMASK",
]
all__own = ["EIXGNN", "SCGNN"]
class ExplainingOutline(object):
def __init__(self, explaining_cfg_path: str):
self.explaining_cfg_path = explaining_cfg_path
self.explaining_cfg = None
self.explainer_cfg_path = None
self.explainer_cfg = None
self.explaining_algorithm = None
self.cfg = None
self.model = None
self.dataset = None
self.model_info = None
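# Loading order matters: the explaining cfg names the model, the model info
# supplies the GraphGym cfg path, and that cfg must be in place before the
# dataset and model can be created from it.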
self.load_explaining_cfg()
self.load_model_info()
self.load_cfg()
self.load_dataset()
self.load_model()
self.load_explainer_cfg()
self.load_explainer()
def load_model_info(self):
info = LoadModelInfo(
dataset_name=self.explaining_cfg.dataset.name,
model_dir=self.explaining_cfg.model.path,
which=self.explaining_cfg.model.ckpt,
)
self.model_info = info.set_info()
def load_cfg(self):
cfg.set_new_allowed(True)
cfg.merge_from_file(self.model_info["cfg_path"])
self.cfg = cfg
def load_explaining_cfg(self):
explaining_cfg.set_new_allowed(True)
explaining_cfg.merge_from_file(self.explaining_cfg_path)
self.explaining_cfg = explaining_cfg
def load_explainer_cfg(self):
if self.explaining_cfg is None:
self.load_explaining_cfg()
else:
if self.explaining_cfg.explainer.cfg == "default":
if self.explaining_cfg.explainer.name == "EIXGNN":
self.explainer_cfg = copy.copy(eixgnn_cfg)
elif self.explaining_cfg.explainer.name == "SCGNN":
self.explainer_cfg = copy.copy(scgnn_cfg)
else:
self.explainer_cfg = None
else:
if self.explaining_cfg.explainer.name == "EIXGNN":
eixgnn_cfg.set_new_allowed(True)
eixgnn_cfg.merge_from_file(self.explaining_cfg.explainer.cfg)
self.explainer_cfg = eixgnn_cfg
elif self.explaining_cfg.explainer.name == "SCGNN":
scgnn_cfg.set_new_allowed(True)
scgnn_cfg.merge_from_file(self.explaining_cfg.explainer.cfg)
self.explainer_cfg = scgnn_cfg
def load_model(self):
if self.cfg is None:
self.load_cfg()
auto_select_device()
self.model = create_model()
self.model = _load_ckpt(self.model, self.model_info["ckpt_path"])
if self.model is None:
raise ValueError("Model ckpt has not been loaded, ckpt file not found")
def load_dataset(self):
if self.cfg is None:
self.load_cfg()
if self.explaining_cfg is None:
self.load_explaining_cfg()
if self.explaining_cfg.dataset.name != self.cfg.dataset.name:
raise ValueError(
f"Expecting that the dataset to perform explanation on is the same as the model has trained on. Get {self.explaining_cfg.dataset.name} for explanation part, and {self.cfg.dataset.name} for the model."
)
self.dataset = create_dataset()
def load_explainer(self):
self.load_explainer_cfg()
if self.model is None:
self.load_model()
if self.dataset is None:
self.load_dataset()
name = self.explaining_cfg.explainer.name
if name in all__captum:
explaining_algorithm = CaptumWrapper(name)
elif name in all__graphxai:
explaining_algorithm = GraphXAIWrapper(
name,
in_channels=self.dataset.num_classes,
criterion=self.cfg.model.loss_fun,
)
elif name in all__own:
if name == "EIXGNN":
explaining_algorithm = EiXGNN(
L=self.explainer_cfg.L,
p=self.explainer_cfg.p,
importance_sampling_strategy=self.explainer_cfg.importance_sampling_strategy,
domain_similarity=self.explainer_cfg.domain_similarity,
signal_similarity=self.explainer_cfg.signal_similarity,
shap_val_approx=self.explainer_cfg.shapley_value_approx,
)
elif name == "SCGNN":
explaining_algorithm = SCGNN(
depth=self.explainer_cfg.depth,
interest_map_norm=self.explainer_cfg.interest_map_norm,
score_map_norm=self.explainer_cfg.score_map_norm,
)
else:
raise ValueError(f"Explainer {name} is not supported")
self.explaining_algorithm = explaining_algorithm
print(self.explaining_algorithm.__dict__)
PATH = "config_exp.yaml"
test = ExplainingOutline(explaining_cfg_path=PATH)
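From the attribute accesses above, the YAML handed to ExplainingOutline needs at least dataset.name, model.path, model.ckpt and explainer.name/explainer.cfg. A hedged sketch that writes such a file; all values are hypothetical:

from yacs.config import CfgNode as CN

cfg_sketch = CN()
cfg_sketch.dataset = CN()
cfg_sketch.dataset.name = "CIFAR10"   # must match the dataset the model was trained on
cfg_sketch.model = CN()
cfg_sketch.model.path = "results/"    # hypothetical GraphGym results root
cfg_sketch.model.ckpt = "best"        # "best", "worst", an int index, or a ckpt path
cfg_sketch.explainer = CN()
cfg_sketch.explainer.name = "SCGNN"
cfg_sketch.explainer.cfg = "default"
with open("config_exp.yaml", "w") as f:
    cfg_sketch.dump(stream=f)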

View File

@@ -23,78 +23,104 @@ def _load_ckpt(
) -> torch.nn.Module:
r"""Loads the model at given checkpoint."""
if not osp.exists(path):
if not os.path.exists(ckpt_path):
return None
ckpt = torch.load(ckpt_path)
model.load_state_dict(ckpt[MODEL_STATE])
model.load_state_dict(ckpt["state_dict"])
return model
def xp_stats(path_to_xp: str, wrt_metric: str = "val") -> str:
acc = []
for path in glob.glob(os.path.join(path_to_xp, "[0-9]", wrt_metric, "stats.json")):
stats = json_to_dict_list(path)
for stat in stats:
acc.append(
{"path": path, "epoch": stat["epoch"], "accuracy": stat["accuracy"]}
class LoadModelInfo(object):
def __init__(
self, dataset_name: str, model_dir: str, which: str, wrt_metric: str = "val"
):
self.dataset_name = dataset_name
self.model_dir = model_dir
self.which = which
self.wrt_metric = wrt_metric
self.info = None
def list_stats(self, path) -> list:
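# Collects one record per epoch from <xp_dir>/<repeat>/<wrt_metric>/stats.json;
# the "[0-9]" glob assumes single-digit repeat directories (GraphGym layout).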
info = []
for path in glob.glob(
os.path.join(path, "[0-9]", self.wrt_metric, "stats.json")
):
stats = json_to_dict_list(path)
for stat in stats:
xp_dir_path = os.path.dirname(os.path.dirname(os.path.dirname(path)))
ckpt_dir_path = os.path.join(
os.path.dirname(os.path.dirname(path)), "ckpt"
)
cfg_path = os.path.join(xp_dir_path, "config.yaml")
epoch = stat["epoch"]
accuracy = stat["accuracy"]
loss = stat["loss"]
lr = stat["lr"]
params = stat["params"]
time_iter = stat["time_iter"]
ckpt_path = self.get_ckpt_path(epoch=epoch, ckpt_dir_path=ckpt_dir_path)
info.append(
{
"xp_dir_path": xp_dir_path,
"ckpt_path": self.get_ckpt_path(
epoch=epoch, ckpt_dir_path=ckpt_dir_path
),
"cfg_path": cfg_path,
"epoch": epoch,
"accuracy": accuracy,
"loss": loss,
"lr": lr,
"params": params,
"time_iter": time_iter,
}
)
return info
def list_xp(self):
paths = []
for path in glob.glob(os.path.join(self.model_dir, "**", "config.yaml")):
file = self.load_cfg(path)
dataset_name_ = file["dataset"]["name"]
if self.dataset_name == dataset_name_:
paths.append(os.path.dirname(path))
return paths
def load_cfg(self, config_path):
return read_yaml(config_path)
def set_info(self):
if self.which in ["best", "worst"] or isinstance(self.which, int):
paths = self.list_xp()
infos = []
for path in paths:
info = self.list_stats(path)
infos.extend(info)
infos = sorted(infos, key=lambda item: item["accuracy"])
if self.which == "best":
self.info = infos[-1]
elif self.which == "worst":
self.info = infos[0]
elif isinstance(self.which, int):
self.info = infos[self.which]
else:
specific_xp_dir_path = os.path.dirname(
os.path.dirname(os.path.dirname(self.which))
)
return acc
stats = self.list_stats(specific_xp_dir_path)
self.info = [item for item in stats if item["ckpt_path"] == self.which][0]
return self.info
def xp_parser_dataset(dataset_name: str, models_dir_path) -> str:
paths = []
for path in glob.glob(os.path.join(models_dir_path, "**", "config.yaml")):
file = read_yaml(path)
dataset_name_ = file["dataset"]["name"]
if dataset_name == dataset_name_:
paths.append(os.path.dirname(path))
return paths
def best_xp_ckpt(paths, which: str = "best"):
acc = []
for path in paths:
accuracies = xp_stats(path)
acc.extend(accuracies)
acc = sorted(acc, key=lambda item: item["accuracy"])
if which == "best":
return acc[-1]
elif which == "worst":
return acc[0]
def stats_to_ckpt(parse):
paths = os.path.join(
os.path.dirname(os.path.dirname(parse["path"])), "ckpt", "*.ckpt"
)
ckpts = []
for path in glob.glob(paths):
feat = os.path.basename(path)
feat = feat.split("-")
for fe in feat:
if "epoch" in fe:
fe_ep = int(fe.split("=")[1])
ckpts.append({"path": path, "div": abs(fe_ep - parse["epoch"])})
return sorted(ckpts, key=lambda item: item["div"])[0]
def stats_to_cfg(parse):
path = os.path.join(
os.path.dirname(os.path.dirname(os.path.dirname(parse["path"])))
)
path = os.path.join(path, "config.yaml")
if os.path.exists(path):
return path
else:
raise FileNotFoundError(f"{path} does not exist")
PATH = "/home/SIC/araison/test_ggym/pytorch_geometric/graphgym/results/"
best = best_xp_ckpt(
paths=xp_parser_dataset(dataset_name="CIFAR10", models_dir_path=PATH), which="worst"
)
print(best)
print(stats_to_ckpt(best))
print(stats_to_cfg(best))
def get_ckpt_path(self, epoch: int, ckpt_dir_path: str):
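# Returns the checkpoint whose "epoch=NN" filename field is closest to the
# requested epoch, assuming Lightning-style names such as epoch=9-step=1234.ckpt.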
paths = os.path.join(ckpt_dir_path, "*.ckpt")
ckpts = []
for path in glob.glob(paths):
feat = os.path.basename(path)
feat = feat.split("-")
for fe in feat:
if "epoch" in fe:
fe_ep = int(fe.split("=")[1])
ckpts.append({"path": path, "div": abs(fe_ep - epoch)})
ckpt = sorted(ckpts, key=lambda item: item["div"])[0]
return ckpt["path"]
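The old free functions (xp_stats, best_xp_ckpt, stats_to_ckpt, stats_to_cfg) are folded into LoadModelInfo. A hedged end-to-end sketch with a hypothetical results directory:

from explaining_framework.utils.explaining.load_ckpt import LoadModelInfo

info = LoadModelInfo(
    dataset_name="CIFAR10",   # hypothetical dataset
    model_dir="results/",     # hypothetical GraphGym results root
    which="best",             # highest accuracy w.r.t. the "val" stats
).set_info()
print(info["ckpt_path"], info["cfg_path"], info["epoch"])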