Fixing bugs and adding new features
This commit is contained in:
parent
e1b2a64bd6
commit
7c4f5ca196
|
@ -66,17 +66,18 @@ set_eixgnn_cfg(eixgnn_cfg)
|
|||
# eixgnn_cfg.dump(stream=f)
|
||||
#
|
||||
#
|
||||
# def load_eixgnn_cfg(eixgnn_cfg, args):
|
||||
# r"""
|
||||
# Load configurations from file system and command line
|
||||
# Args:
|
||||
# eixgnn_cfg (CfgNode): Configuration node
|
||||
# args (ArgumentParser): Command argument parser
|
||||
# """
|
||||
# eixgnn_cfg.merge_from_file(args.eixgnn_cfg_file)
|
||||
# eixgnn_cfg.merge_from_list(args.opts)
|
||||
# assert_eixgnn_cfg(eixgnn_cfg)
|
||||
#
|
||||
def load_eixgnn_cfg(eixgnn_cfg, args):
|
||||
r"""
|
||||
Load configurations from file system and command line
|
||||
Args:
|
||||
eixgnn_cfg (CfgNode): Configuration node
|
||||
args (ArgumentParser): Command argument parser
|
||||
"""
|
||||
eixgnn_cfg.merge_from_file(args.eixgnn_cfg_file)
|
||||
eixgnn_cfg.merge_from_list(args.opts)
|
||||
assert_eixgnn_cfg(eixgnn_cfg)
|
||||
|
||||
|
||||
#
|
||||
# def makedirs_rm_exist(dir):
|
||||
# if os.path.isdir(dir):
|
||||
|
|
|
@ -36,27 +36,25 @@ def set_scgnn_cfg(scgnn_cfg):
|
|||
if scgnn_cfg is None:
|
||||
return scgnn_cfg
|
||||
|
||||
scgnn_cfg.depth = 'all'
|
||||
scgnn_cfg.depth = "all"
|
||||
scgnn_cfg.interest_map_norm = True
|
||||
scgnn_cfg.score_map_norm = True
|
||||
|
||||
|
||||
|
||||
|
||||
def assert_scgnn_cfg(scgnn_cfg):
|
||||
def assert_cfg(scgnn_cfg):
|
||||
r"""Checks config values, do necessary post processing to the configs
|
||||
TODO
|
||||
TODO
|
||||
|
||||
"""
|
||||
if scgnn_cfg. not in ["node", "edge", "graph", "link_pred"]:
|
||||
raise ValueError(
|
||||
"Task {} not supported, must be one of node, "
|
||||
"edge, graph, link_pred".format(scgnn_cfg.dataset.task)
|
||||
)
|
||||
scgnn_cfg.run_dir = scgnn_cfg.out_dir
|
||||
"""
|
||||
# if scgnn_cfg. not in ["node", "edge", "graph", "link_pred"]:
|
||||
# raise ValueError(
|
||||
# "Task {} not supported, must be one of node, "
|
||||
# "edge, graph, link_pred".format(scgnn_cfg.dataset.task)
|
||||
# )
|
||||
# scgnn_cfg.run_dir = scgnn_cfg.out_dir
|
||||
|
||||
|
||||
def dump_scgnn_cfg(scgnn_cfg,path):
|
||||
def dump_cfg(scgnn_cfg, path):
|
||||
r"""
|
||||
TODO
|
||||
Dumps the config to the output directory specified in
|
||||
|
@ -65,14 +63,12 @@ def dump_scgnn_cfg(scgnn_cfg,path):
|
|||
scgnn_cfg (CfgNode): Configuration node
|
||||
"""
|
||||
makedirs(scgnn_cfg.out_dir)
|
||||
scgnn_cfg_file = os.path.join(
|
||||
scgnn_cfg.out_dir, scgnn_cfg.scgnn_cfg_dest
|
||||
)
|
||||
scgnn_cfg_file = os.path.join(scgnn_cfg.out_dir, scgnn_cfg.scgnn_cfg_dest)
|
||||
with open(scgnn_cfg_file, "w") as f:
|
||||
scgnn_cfg.dump(stream=f)
|
||||
|
||||
|
||||
def load_scgnn_cfg(scgnn_cfg, args):
|
||||
def load_cfg(scgnn_cfg, args):
|
||||
r"""
|
||||
Load configurations from file system and command line
|
||||
Args:
|
||||
|
@ -116,8 +112,6 @@ def set_out_dir(out_dir, fname):
|
|||
# Make output directory
|
||||
if scgnn_cfg.train.auto_resume:
|
||||
os.makedirs(scgnn_cfg.out_dir, exist_ok=True)
|
||||
else:
|
||||
makedirs_rm_exist(scgnn_cfg.out_dir)
|
||||
|
||||
|
||||
def set_run_dir(out_dir):
|
||||
|
@ -131,8 +125,6 @@ def set_run_dir(out_dir):
|
|||
# Make output directory
|
||||
if scgnn_cfg.train.auto_resume:
|
||||
os.makedirs(scgnn_cfg.run_dir, exist_ok=True)
|
||||
else:
|
||||
makedirs_rm_exist(scgnn_cfg.run_dir)
|
||||
|
||||
|
||||
set_scgnn_cfg(scgnn_cfg)
|
||||
|
|
|
@ -122,7 +122,8 @@ def set_cfg(explaining_cfg):
|
|||
explaining_cfg.accelerator = "auto"
|
||||
|
||||
# which objectives metrics to computes, either all or one in particular if implemented
|
||||
explaining_cfg.metrics = "all"
|
||||
explaining_cfg.metrics = CN()
|
||||
explaining_cfg.metrics.type = "all"
|
||||
|
||||
# Whether or not recomputing metrics if they already exist
|
||||
explaining_cfg.metrics.force = False
|
||||
|
|
|
@ -86,7 +86,7 @@ class GraphXAIWrapper(ExplainerAlgorithm):
|
|||
if criterion == "mse":
|
||||
loss = MSELoss()
|
||||
return loss
|
||||
elif criterion == "cross-entropy":
|
||||
elif criterion == "cross_entropy":
|
||||
loss = CrossEntropyLoss()
|
||||
return loss
|
||||
else:
|
||||
|
|
|
@ -5,13 +5,6 @@ def parse_args() -> argparse.Namespace:
|
|||
r"""Parses the command line arguments."""
|
||||
parser = argparse.ArgumentParser(description="ExplainingFramework")
|
||||
|
||||
parser.add_argument(
|
||||
"--cfg",
|
||||
dest="cfg_file",
|
||||
type=str,
|
||||
required=True,
|
||||
help="The configuration file path.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--explaining_cfg",
|
||||
dest="explaining_cfg_file",
|
||||
|
|
|
@ -1,14 +0,0 @@
|
|||
class ExplainingOutline(object):
|
||||
def __init__(self, explaining_cfg: CN, explainer_cfg: CN = None):
|
||||
self.explaining_cfg = explaining_cfg
|
||||
self.explainer_cfg = explainer_cfg
|
||||
|
||||
|
||||
|
||||
def load_cfg(self):
|
||||
if self
|
||||
|
||||
|
||||
|
||||
|
||||
|
|
@ -0,0 +1,174 @@
|
|||
import copy
|
||||
|
||||
from eixgnn.eixgnn import EiXGNN
|
||||
from explaining_framework.config.explainer_config.eixgnn_config import \
|
||||
eixgnn_cfg
|
||||
from explaining_framework.config.explainer_config.scgnn_config import scgnn_cfg
|
||||
from explaining_framework.config.explaining_config import explaining_cfg
|
||||
from explaining_framework.explainers.wrappers.from_captum import CaptumWrapper
|
||||
from explaining_framework.explainers.wrappers.from_graphxai import \
|
||||
GraphXAIWrapper
|
||||
from explaining_framework.metric.accuracy import Accuracy
|
||||
from explaining_framework.metric.fidelity import Fidelity
|
||||
from explaining_framework.metric.robust import Attack
|
||||
from explaining_framework.metric.sparsity import Sparsity
|
||||
from explaining_framework.utils.explaining.load_ckpt import (LoadModelInfo,
|
||||
_load_ckpt)
|
||||
from scgnn.scgnn import SCGNN
|
||||
from torch_geometric.data import Batch, Data
|
||||
from torch_geometric.explain import Explainer
|
||||
from torch_geometric.graphgym.config import cfg
|
||||
from torch_geometric.graphgym.loader import create_dataset
|
||||
from torch_geometric.graphgym.model_builder import cfg, create_model
|
||||
from torch_geometric.graphgym.utils.device import auto_select_device
|
||||
|
||||
all__captum = [
|
||||
"LRP",
|
||||
"DeepLift",
|
||||
"DeepLiftShap",
|
||||
"FeatureAblation",
|
||||
"FeaturePermutation",
|
||||
"GradientShap",
|
||||
"GuidedBackprop",
|
||||
"GuidedGradCam",
|
||||
"InputXGradient",
|
||||
"IntegratedGradients",
|
||||
"Lime",
|
||||
"Occlusion",
|
||||
"Saliency",
|
||||
]
|
||||
|
||||
all__graphxai = [
|
||||
"CAM",
|
||||
"GradCAM",
|
||||
"GNN_LRP",
|
||||
"GradExplainer",
|
||||
"GuidedBackPropagation",
|
||||
"IntegratedGradients",
|
||||
"PGExplainer",
|
||||
"PGMExplainer",
|
||||
"RandomExplainer",
|
||||
"SubgraphX",
|
||||
"GraphMASK",
|
||||
]
|
||||
|
||||
all__own = ["EIXGNN", "SCGNN"]
|
||||
|
||||
|
||||
class ExplainingOutline(object):
|
||||
def __init__(self, explaining_cfg_path: str):
|
||||
self.explaining_cfg_path = explaining_cfg_path
|
||||
self.explaining_cfg = None
|
||||
self.explainer_cfg_path = None
|
||||
self.explainer_cfg = None
|
||||
self.explaining_algorithm = None
|
||||
self.cfg = None
|
||||
self.model = None
|
||||
self.dataset = None
|
||||
self.model_info = None
|
||||
|
||||
self.load_explaining_cfg()
|
||||
self.load_model_info()
|
||||
self.load_cfg()
|
||||
self.load_dataset()
|
||||
self.load_model()
|
||||
self.load_explainer_cfg()
|
||||
self.load_explainer()
|
||||
|
||||
def load_model_info(self):
|
||||
info = LoadModelInfo(
|
||||
dataset_name=self.explaining_cfg.dataset.name,
|
||||
model_dir=self.explaining_cfg.model.path,
|
||||
which=self.explaining_cfg.model.ckpt,
|
||||
)
|
||||
self.model_info = info.set_info()
|
||||
|
||||
def load_cfg(self):
|
||||
cfg.set_new_allowed(True)
|
||||
cfg.merge_from_file(self.model_info["cfg_path"])
|
||||
self.cfg = cfg
|
||||
|
||||
def load_explaining_cfg(self):
|
||||
explaining_cfg.set_new_allowed(True)
|
||||
explaining_cfg.merge_from_file(self.explaining_cfg_path)
|
||||
self.explaining_cfg = explaining_cfg
|
||||
|
||||
def load_explainer_cfg(self):
|
||||
if self.explaining_cfg is None:
|
||||
self.explaining_cfg()
|
||||
else:
|
||||
if self.explaining_cfg.explainer.cfg == "default":
|
||||
if self.explaining_cfg.explainer.name == "EIXGNN":
|
||||
self.explainer_cfg = copy.copy(eixgnn_cfg)
|
||||
elif self.explaining_cfg.explainer.name == "SCGNN":
|
||||
self.explainer_cfg = copy.copy(scgnn_cfg)
|
||||
else:
|
||||
self.explainer_cfg = None
|
||||
else:
|
||||
if self.explaining_cfg.explainer.name == "EIXGNN":
|
||||
eixgnn_cfg.set_new_allowed(True)
|
||||
eixgnn_cfg.merge_from_file(self.explaining_cfg.explainer.cfg)
|
||||
self.explainer_cfg = eixgnn_cfg
|
||||
elif self.explaining_cfg.explainer.name == "SCGNN":
|
||||
scgnn_cfg.set_new_allowed(True)
|
||||
scgnn_cfg.merge_from_file(self.explaining_cfg.explainer.cfg)
|
||||
self.explainer_cfg = scgnn_cfg
|
||||
|
||||
def load_model(self):
|
||||
if self.cfg is None:
|
||||
self.load_cfg()
|
||||
auto_select_device()
|
||||
self.model = create_model()
|
||||
self.model = _load_ckpt(self.model, self.model_info["ckpt_path"])
|
||||
if self.model is None:
|
||||
raise ValueError("Model ckpt has not been loaded, ckpt file not found")
|
||||
|
||||
def load_dataset(self):
|
||||
if self.cfg is None:
|
||||
self.load_cfg()
|
||||
if self.explaining_cfg is None:
|
||||
self.explaining_cfg()
|
||||
if self.explaining_cfg.dataset.name != self.cfg.dataset.name:
|
||||
raise ValueError(
|
||||
f"Expecting that the dataset to perform explanation on is the same as the model has trained on. Get {self.explaining_cfg.dataset.name} for explanation part, and {self.cfg.dataset.name} for the model."
|
||||
)
|
||||
self.dataset = create_dataset()
|
||||
|
||||
def load_explainer(self):
|
||||
self.load_explainer_cfg()
|
||||
if self.model is None:
|
||||
self.load_model()
|
||||
if self.dataset is None:
|
||||
self.load_dataset()
|
||||
|
||||
name = self.explaining_cfg.explainer.name
|
||||
if name in all__captum:
|
||||
explaining_algorithm = CaptumWrapper(name)
|
||||
elif name in all__graphxai:
|
||||
explaining_algorithm = GraphXAIWrapper(
|
||||
name,
|
||||
in_channels=self.dataset.num_classes,
|
||||
criterion=self.cfg.model.loss_fun,
|
||||
)
|
||||
elif name in all__own:
|
||||
if name == "EIXGNN":
|
||||
explaining_algorithm = EiXGNN(
|
||||
L=self.explainer_cfg.L,
|
||||
p=self.explainer_cfg.p,
|
||||
importance_sampling_strategy=self.explainer_cfg.importance_sampling_strategy,
|
||||
domain_similarity=self.explainer_cfg.domain_similarity,
|
||||
signal_similarity=self.explainer_cfg.signal_similarity,
|
||||
shap_val_approx=self.explainer_cfg.shapley_value_approx,
|
||||
)
|
||||
elif name == "SCGNN":
|
||||
explaining_algorithm = SCGNN(
|
||||
depth=self.explainer_cfg.depth,
|
||||
interest_map_norm=self.explainer_cfg.interest_map_norm,
|
||||
score_map_norm=self.explainer_cfg.score_map_norm,
|
||||
)
|
||||
self.explaining_algorithm = explaining_algorithm
|
||||
print(self.explaining_algorithm.__dict__)
|
||||
|
||||
|
||||
PATH = "config_exp.yaml"
|
||||
test = ExplainingOutline(explaining_cfg_path=PATH)
|
|
@ -23,78 +23,104 @@ def _load_ckpt(
|
|||
) -> torch.nn.Module:
|
||||
r"""Loads the model at given checkpoint."""
|
||||
|
||||
if not osp.exists(path):
|
||||
if not os.path.exists(ckpt_path):
|
||||
return None
|
||||
|
||||
ckpt = torch.load(ckpt_path)
|
||||
model.load_state_dict(ckpt[MODEL_STATE])
|
||||
model.load_state_dict(ckpt["state_dict"])
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def xp_stats(path_to_xp: str, wrt_metric: str = "val") -> str:
|
||||
acc = []
|
||||
for path in glob.glob(os.path.join(path_to_xp, "[0-9]", wrt_metric, "stats.json")):
|
||||
stats = json_to_dict_list(path)
|
||||
for stat in stats:
|
||||
acc.append(
|
||||
{"path": path, "epoch": stat["epoch"], "accuracy": stat["accuracy"]}
|
||||
class LoadModelInfo(object):
|
||||
def __init__(
|
||||
self, dataset_name: str, model_dir: str, which: str, wrt_metric: str = "val"
|
||||
):
|
||||
self.dataset_name = dataset_name
|
||||
self.model_dir = model_dir
|
||||
self.which = which
|
||||
self.wrt_metric = wrt_metric
|
||||
self.info = None
|
||||
|
||||
def list_stats(self, path) -> list:
|
||||
info = []
|
||||
for path in glob.glob(
|
||||
os.path.join(path, "[0-9]", self.wrt_metric, "stats.json")
|
||||
):
|
||||
stats = json_to_dict_list(path)
|
||||
for stat in stats:
|
||||
xp_dir_path = os.path.dirname(os.path.dirname(os.path.dirname(path)))
|
||||
ckpt_dir_path = os.path.join(
|
||||
os.path.dirname(os.path.dirname(path)), "ckpt"
|
||||
)
|
||||
cfg_path = os.path.join(xp_dir_path, "config.yaml")
|
||||
epoch = stat["epoch"]
|
||||
accuracy = stat["accuracy"]
|
||||
loss = stat["loss"]
|
||||
lr = stat["lr"]
|
||||
params = stat["params"]
|
||||
time_iter = stat["time_iter"]
|
||||
ckpt_path = self.get_ckpt_path(epoch=epoch, ckpt_dir_path=ckpt_dir_path)
|
||||
info.append(
|
||||
{
|
||||
"xp_dir_path": xp_dir_path,
|
||||
"ckpt_path": self.get_ckpt_path(
|
||||
epoch=epoch, ckpt_dir_path=ckpt_dir_path
|
||||
),
|
||||
"cfg_path": cfg_path,
|
||||
"epoch": epoch,
|
||||
"accuracy": accuracy,
|
||||
"loss": loss,
|
||||
"lr": lr,
|
||||
"params": params,
|
||||
"time_iter": time_iter,
|
||||
}
|
||||
)
|
||||
return info
|
||||
|
||||
def list_xp(self):
|
||||
paths = []
|
||||
for path in glob.glob(os.path.join(self.model_dir, "**", "config.yaml")):
|
||||
file = self.load_cfg(path)
|
||||
dataset_name_ = file["dataset"]["name"]
|
||||
if self.dataset_name == dataset_name_:
|
||||
paths.append(os.path.dirname(path))
|
||||
return paths
|
||||
|
||||
def load_cfg(self, config_path):
|
||||
return read_yaml(config_path)
|
||||
|
||||
def set_info(self):
|
||||
if self.which in ["best", "worst"] or isinstance(self.which, int):
|
||||
paths = self.list_xp()
|
||||
infos = []
|
||||
for path in paths:
|
||||
info = self.list_stats(path)
|
||||
infos.extend(info)
|
||||
infos = sorted(infos, key=lambda item: item["accuracy"])
|
||||
if self.which == "best":
|
||||
self.info = infos[-1]
|
||||
elif self.which == "worst":
|
||||
self.info = infos[0]
|
||||
elif isinstance(self.which, int):
|
||||
self.info = infos[self.which]
|
||||
else:
|
||||
specific_xp_dir_path = os.path.dirname(
|
||||
os.path.dirname(os.path.dirname(self.which))
|
||||
)
|
||||
return acc
|
||||
stats = self.list_stats(specific_xp_dir_path)
|
||||
self.info = [item for item in stats if item["ckpt_path"] == self.which][0]
|
||||
return self.info
|
||||
|
||||
|
||||
def xp_parser_dataset(dataset_name: str, models_dir_path) -> str:
|
||||
paths = []
|
||||
for path in glob.glob(os.path.join(models_dir_path, "**", "config.yaml")):
|
||||
file = read_yaml(path)
|
||||
dataset_name_ = file["dataset"]["name"]
|
||||
if dataset_name == dataset_name_:
|
||||
paths.append(os.path.dirname(path))
|
||||
return paths
|
||||
|
||||
|
||||
def best_xp_ckpt(paths, which: str = "best"):
|
||||
acc = []
|
||||
for path in paths:
|
||||
accuracies = xp_stats(path)
|
||||
acc.extend(accuracies)
|
||||
acc = sorted(acc, key=lambda item: item["accuracy"])
|
||||
if which == "best":
|
||||
return acc[-1]
|
||||
elif which == "worst":
|
||||
return acc[0]
|
||||
|
||||
|
||||
def stats_to_ckpt(parse):
|
||||
paths = os.path.join(
|
||||
os.path.dirname(os.path.dirname(parse["path"])), "ckpt", "*.ckpt"
|
||||
)
|
||||
ckpts = []
|
||||
for path in glob.glob(paths):
|
||||
feat = os.path.basename(path)
|
||||
feat = feat.split("-")
|
||||
for fe in feat:
|
||||
if "epoch" in fe:
|
||||
fe_ep = int(fe.split("=")[1])
|
||||
ckpts.append({"path": path, "div": abs(fe_ep - parse["epoch"])})
|
||||
return sorted(ckpts, key=lambda item: item["div"])[0]
|
||||
|
||||
|
||||
def stats_to_cfg(parse):
|
||||
path = os.path.join(
|
||||
os.path.dirname(os.path.dirname(os.path.dirname(parse["path"])))
|
||||
)
|
||||
path = os.path.join(path, "config.yaml")
|
||||
if os.path.exists(path):
|
||||
return path
|
||||
else:
|
||||
raise FileNotFoundError(f"{path} does not exists")
|
||||
|
||||
|
||||
PATH = "/home/SIC/araison/test_ggym/pytorch_geometric/graphgym/results/"
|
||||
best = best_xp_ckpt(
|
||||
paths=xp_parser_dataset(dataset_name="CIFAR10", models_dir_path=PATH), which="worst"
|
||||
)
|
||||
print(best)
|
||||
print(stats_to_ckpt(best))
|
||||
print(stats_to_cfg(best))
|
||||
def get_ckpt_path(self, epoch: int, ckpt_dir_path: str):
|
||||
paths = os.path.join(ckpt_dir_path, "*.ckpt")
|
||||
ckpts = []
|
||||
for path in glob.glob(paths):
|
||||
feat = os.path.basename(path)
|
||||
feat = feat.split("-")
|
||||
for fe in feat:
|
||||
if "epoch" in fe:
|
||||
fe_ep = int(fe.split("=")[1])
|
||||
ckpts.append({"path": path, "div": abs(fe_ep - epoch)})
|
||||
ckpt = sorted(ckpts, key=lambda item: item["div"])[0]
|
||||
return ckpt["path"]
|
||||
|
|
Loading…
Reference in New Issue