Adding new features and first draft of main.py

This commit is contained in:
araison 2022-12-30 19:34:41 +01:00
parent a00e73d4f0
commit 074ff25c83
5 changed files with 203 additions and 17 deletions

View File

@ -49,11 +49,4 @@ class Metric(ABC):
return out
def save_config(self, path) -> None:
    """Write this object's scalar attributes to ``path`` as JSON.

    Every name reported by ``dir(self)`` is inspected and kept when its value
    is an ``int``, ``float``, ``str``, ``bool`` or ``None``. Dunder attributes
    such as ``__doc__`` and ``__module__`` are skipped: they also hold strings
    or ``None`` but describe the class rather than its configuration.

    Args:
        path: Destination handed to ``write_json``.
    """
    config = {}
    for name in dir(self):
        if name.startswith("__") and name.endswith("__"):
            continue
        value = getattr(self, name)
        if value is None or isinstance(value, (int, float, str, bool)):
            config[name] = value
    write_json(config, path)

View File

@ -7,11 +7,12 @@ import logging
import os
import torch
from explaining_framework.utils.io import read_yaml
from torch_geometric.graphgym.model_builder import create_model
from torch_geometric.graphgym.train import GraphGymDataModule
from torch_geometric.graphgym.utils.io import json_to_dict_list
from explaining_framework.utils.io import read_yaml
MODEL_STATE = "model_state"
OPTIMIZER_STATE = "optimizer_state"
SCHEDULER_STATE = "scheduler_state"
@ -44,14 +45,15 @@ class LoadModelInfo(object):
def list_stats(self, path) -> list:
info = []
for path in glob.glob(
for path_ in glob.glob(
os.path.join(path, "[0-9]", self.wrt_metric, "stats.json")
):
stats = json_to_dict_list(path)
stats = json_to_dict_list(path_)
for stat in stats:
xp_dir_path = os.path.dirname(os.path.dirname(os.path.dirname(path)))
xp_dir_path = os.path.dirname(os.path.dirname(os.path.dirname(path_)))
seed = int(os.path.basename(os.path.dirname(os.path.dirname(path_))))
ckpt_dir_path = os.path.join(
os.path.dirname(os.path.dirname(path)), "ckpt"
os.path.dirname(os.path.dirname(path_)), "ckpt"
)
cfg_path = os.path.join(xp_dir_path, "config.yaml")
epoch = stat["epoch"]
@ -68,12 +70,16 @@ class LoadModelInfo(object):
epoch=epoch, ckpt_dir_path=ckpt_dir_path
),
"cfg_path": cfg_path,
"seed": seed,
"epoch": epoch,
"accuracy": accuracy,
"loss": loss,
"lr": lr,
"params": params,
"time_iter": time_iter,
"which": self.which
if self.which in ["best", "worst"]
else None,
}
)
return info
@ -112,6 +118,22 @@ class LoadModelInfo(object):
self.info = [item for item in stats if item["ckpt_path"] == self.which][0]
return self.info
def get_model_signature(self):
    """Return a string identifying the selected model checkpoint.

    The signature is ``name=<basename of xp_dir_name>-seed=<seed>-<basename of
    ckpt_path>``, all read from ``self.info``. ``set_info`` is called first
    when ``self.info`` has not been populated yet.
    """
    if self.info is None:
        self.set_info()

    info = self.info
    labelled = {
        "name": os.path.basename(info["xp_dir_name"]),
        "seed": info["seed"],
    }
    parts = [f"{key}={value}" for key, value in labelled.items()]
    # The checkpoint file name is appended as-is, without a "key=" prefix.
    parts.append(os.path.basename(info["ckpt_path"]))
    return "-".join(parts)
def get_ckpt_path(self, epoch: int, ckpt_dir_path: str):
paths = os.path.join(ckpt_dir_path, "*.ckpt")
ckpts = []

View File

@ -4,6 +4,7 @@ from typing import Any
from eixgnn.eixgnn import EiXGNN
from scgnn.scgnn import SCGNN
from torch_geometric.data import Batch, Data
from torch_geometric.data.loader.dataloader import DataLoader
from torch_geometric.explain import Explainer
from torch_geometric.graphgym.config import cfg
from torch_geometric.graphgym.loader import create_dataset
@ -62,6 +63,8 @@ all_fidelity = [
"fidelity_plus_prob",
"fidelity_minus_prob",
"infidelity_KL",
"characterization",
"characterization_prob",
]
all_accuracy = [
"precision_score",
@ -94,6 +97,7 @@ class ExplainingOutline(object):
self.model_info = None
self.metrics = None
self.attacks = None
self.model_signature = None
self.load_explaining_cfg()
self.load_model_info()
@ -112,6 +116,7 @@ class ExplainingOutline(object):
which=self.explaining_cfg.model.ckpt,
)
self.model_info = info.set_info()
self.model_signature = info.get_model_signature()
def load_cfg(self):
cfg.set_new_allowed(True)
@ -166,6 +171,7 @@ class ExplainingOutline(object):
if isinstance(self.explaining_cfg.dataset.specific_items, int):
ind = self.explaining_cfg.dataset.specific_items
self.dataset = self.dataset[ind : ind + 1]
self.dataset = DataLoader(dataset=dataset, shuffle=False, batch_size=1)
def load_explainer(self):
self.load_explainer_cfg()
@ -199,6 +205,10 @@ class ExplainingOutline(object):
interest_map_norm=self.explainer_cfg.interest_map_norm,
score_map_norm=self.explainer_cfg.score_map_norm,
)
elif name is None:
explaining_algorithm = None
else:
raise ValueError(f"{name_} Metric is not supported yet")
self.explaining_algorithm = explaining_algorithm
def load_metric(self):
@ -228,6 +238,10 @@ class ExplainingOutline(object):
raise ValueError(
f"The metric {name} is not supported for dataset {self.explaining_cfg.dataset.name} yet, it requires groundtruth explanation"
)
elif name_ is None:
self.metrics = []
else:
raise ValueError(f"{name_} Metric is not supported yet")
def load_attack(self):
if self.cfg is None:
@ -238,5 +252,9 @@ class ExplainingOutline(object):
if name_ == "all":
all_rob_metrics = [Attack(name) for name in all_robust]
self.attacks = all_rob_metrics
if name_ in all_robust:
elif name_ in all_robust:
self.attacks = [Attack(name_)]
elif name_ is None:
slef.attacks = []
else:
raise ValueError(f"{name_} is an Attack method that is not supported yet")

View File

@ -1,5 +1,6 @@
import json
import os
import yaml
@ -24,6 +25,26 @@ def write_yaml(data: dict, path: str) -> None:
with open(path, "w") as f:
data = yaml.dump(data, f)
def is_exists(path: str) -> bool:
    """Return ``True`` when ``path`` refers to an existing file or directory."""
    return os.path.exists(path)
def get_obj_config(obj):
    """Collect the scalar attributes of ``obj`` into a dictionary.

    Every name reported by ``dir(obj)`` is inspected and kept when its value
    is an ``int``, ``float``, ``str``, ``bool`` or ``None``. Dunder attributes
    such as ``__doc__`` and ``__module__`` are skipped: they also hold strings
    or ``None`` but describe the class rather than the object's configuration,
    and would otherwise end up in every serialized config.

    Args:
        obj: Any object; ``None`` yields an empty dictionary.

    Returns:
        A ``dict`` mapping attribute names to their scalar values, in the
        (sorted) order produced by ``dir``.
    """
    config = {}
    for name in dir(obj):
        if name.startswith("__") and name.endswith("__"):
            continue
        value = getattr(obj, name)
        if value is None or isinstance(value, (int, float, str, bool)):
            config[name] = value
    return config
def save_obj_config(obj, path) -> None:
    """Serialize the configuration of ``obj`` to ``path``.

    The configuration is whatever ``get_obj_config`` returns for ``obj``; it
    is written with ``write_json``.
    """
    write_json(get_obj_config(obj), path)
def obj_config_to_str(obj) -> str:
    """Render the configuration of ``obj`` as ``key=value`` pairs joined by ``-``.

    The pairs come from ``get_obj_config`` and keep that dictionary's order.
    """
    pairs = []
    for key, value in get_obj_config(obj).items():
        pairs.append(f"{key}={value}")
    return "-".join(pairs)

134
main.py
View File

@ -3,6 +3,138 @@
#
import os
import time
from torch_geometric import seed_everything
from torch_geometric.data.makedirs import makedirs
from torch_geometric.explain import Explainer
from torch_geometric.explain.config import ThresholdConfig
from torch_geometric.graphgym.config import cfg
from torch_geometric.graphgym.utils.device import auto_select_device
from explaining_framework.config.explaining_config import explaining_cfg
from explaining_framework.utils.explaining.cmd_args import parse_args
from explaining_framework.utils.explaining.outline import parse_args
from explaining_framework.utils.explaining.outline import ExplainingOutline
from explaining_framework.utils.io import (obj_config_to_str, read_json,
write_json, write_yaml)
from explaining_framework.utils.explanation.adjust import Adjust
# inference, time, force,
def get_pred(explanation, force=False):
    """Attach the plain and masked predictions to ``explanation``.

    The predictions are computed with the module-level ``explainer`` and
    stored as the ``pred`` and ``pred_masked`` attributes. They are only
    computed when at least one of them is missing from
    ``explanation.to_dict()`` or when ``force`` is true.

    Args:
        explanation: Object exposing ``to_dict()`` plus the ``x``,
            ``edge_index``, ``node_mask`` and ``edge_mask`` attributes.
        force: Recompute both predictions even if they are already stored.

    Returns:
        The same ``explanation`` object, updated in place when needed.
    """
    stored = explanation.to_dict()
    # Both entries are compared against None explicitly: a truthiness test on
    # a stored prediction would trigger a recomputation exactly when the value
    # is already there (and is ambiguous for multi-element tensors).
    missing = stored.get("pred") is None or stored.get("pred_masked") is None
    if missing or force:
        pred = explainer.get_prediction(explanation)
        pred_masked = explainer.get_masked_prediction(
            x=explanation.x,
            edge_index=explanation.edge_index,
            node_mask=explanation.node_mask,
            edge_mask=explanation.edge_mask,
        )
        explanation.pred = pred
        explanation.pred_masked = pred_masked
    return explanation
if __name__ == "__main__":
    # Entry point: explain every dataset item, derive adjusted and thresholded
    # variants of each explanation, and score the thresholded ones with every
    # configured metric. Results are cached on disk as one JSON file per item.

    # `copy` and `is_exists` are used below but are missing from the
    # module-level imports, so they are brought into scope here.
    import copy

    from explaining_framework.utils.io import is_exists

    # NOTE(review): `load_explanation` and `save_explanation` are called below
    # but are not imported in the visible part of this file — confirm where
    # they are defined and import them.

    args = parse_args()
    outline = ExplainingOutline(args.explaining_cfg_file)

    auto_select_device()

    # Load components
    # NOTE(review): ExplainingOutline assigns a DataLoader to `self.dataset`;
    # confirm that object supports `.to(...)`.
    dataset = outline.dataset.to(cfg.accelerator)
    model = outline.model.to(cfg.accelerator)
    model_info = outline.model_info
    metrics = outline.metrics
    explaining_algorithm = outline.explaining_algorithm
    attacks = outline.attacks
    explainer_cfg = outline.explainer_cfg
    model_signature = outline.model_signature

    # Set seed
    seed_everything(explaining_cfg.seed)

    # Output tree: <out_dir>/<model signature>/<explainer>/<explainer config>/
    global_path = os.path.join(explaining_cfg.out_dir, model_signature)
    makedirs(global_path)
    write_yaml(cfg, os.path.join(global_path, "config.yaml"))
    write_json(model_info, os.path.join(global_path, "info.json"))

    global_path = os.path.join(
        global_path,
        explaining_cfg.explainer.name + "_" + obj_config_to_str(explaining_algorithm),
    )
    makedirs(global_path)
    write_yaml(explaining_cfg, os.path.join(global_path, explaining_cfg.cfg_dest))
    write_yaml(explainer_cfg, os.path.join(global_path, "explainer_cfg.yaml"))

    global_path = os.path.join(global_path, obj_config_to_str(explaining_algorithm))
    makedirs(global_path)

    explainer = Explainer(
        model=model,
        algorithm=explaining_algorithm,
        explainer_config=dict(
            explanation_type=explaining_cfg.explanation_type,
            node_mask_type="object",
            edge_mask_type="object",
        ),
        model_config=dict(
            mode="regression",
            task_level=cfg.dataset.task,
            return_type=explaining_cfg.model_config.return_type,
        ),
    )

    raw_dir = os.path.join(global_path, "raw")
    makedirs(raw_dir)

    for index, item in enumerate(dataset):
        # --- Raw explanation: reuse the cached file unless `force` is set.
        raw_path = os.path.join(raw_dir, f"{index}.json")
        if is_exists(raw_path) and not explaining_cfg.explainer.force:
            raw_explanation = load_explanation(raw_path)
        else:
            raw_explanation = explainer(
                x=item.x,
                edge_index=item.edge_index,
                index=item.y,
                target=item.y,
            )
        raw_explanation = get_pred(raw_explanation, force=False)
        save_explanation(raw_explanation, raw_path)

        for apply_relu in [True, False]:
            for apply_absolute in [True, False]:
                # --- Adjusted explanation.
                adjust = Adjust(apply_relu=apply_relu, apply_absolute=apply_absolute)
                adjust_dir = os.path.join(
                    global_path, f"adjust-{obj_config_to_str(adjust)}"
                )
                makedirs(adjust_dir)
                # Every adjustment starts from the raw explanation instead of
                # the output of the previous adjustment. The shallow copy
                # mirrors the one made before thresholding below; it assumes
                # Adjust.forward does not modify tensors in place — TODO confirm.
                adjusted = adjust.forward(copy.copy(raw_explanation))
                adjusted = get_pred(adjusted, force=True)
                save_explanation(adjusted, os.path.join(adjust_dir, f"{index}.json"))

                for threshold_approach in ["hard", "topk", "topk_hard"]:
                    for threshold_value in explaining_cfg.threshold_config.value:
                        # --- Thresholded explanation.
                        masking_path = os.path.join(
                            adjust_dir,
                            f"threshold={threshold_approach}-value={threshold_value}",
                        )
                        makedirs(masking_path)
                        exp_threshold_path = os.path.join(
                            masking_path, f"{index}.json"
                        )
                        if is_exists(exp_threshold_path):
                            exp_threshold = load_explanation(exp_threshold_path)
                        else:
                            threshold_conf = {
                                "threshold_type": threshold_approach,
                                "value": threshold_value,
                            }
                            explainer.threshold_config = ThresholdConfig.cast(
                                threshold_conf
                            )
                            expl = copy.copy(adjusted)
                            exp_threshold = explainer._post_process(expl)
                            exp_threshold = get_pred(exp_threshold, force=True)
                            save_explanation(exp_threshold, exp_threshold_path)

                        # --- Metrics: skip those already computed for this item.
                        for metric in metrics:
                            metric_path = os.path.join(
                                masking_path, f"{obj_config_to_str(metric)}"
                            )
                            metric_file = os.path.join(metric_path, f"{index}.json")
                            if is_exists(metric_file):
                                continue
                            makedirs(metric_path)
                            out = metric.forward(exp_threshold)
                            write_json({f"{metric.name}": out}, metric_file)