Adding new features and first draft of main.py
This commit is contained in:
parent
a00e73d4f0
commit
074ff25c83
@ -49,11 +49,4 @@ class Metric(ABC):
|
||||
|
||||
return out
|
||||
|
||||
def save_config(self, path) -> None:
|
||||
config = {k: getattr(self, k) for k in dir(self)}
|
||||
config = {
|
||||
k: v
|
||||
for k, v in config.items()
|
||||
if isinstance(v, (int, float, str, bool)) or v is None
|
||||
}
|
||||
write_json(config, path)
|
||||
|
||||
|
@ -7,11 +7,12 @@ import logging
|
||||
import os
|
||||
|
||||
import torch
|
||||
from explaining_framework.utils.io import read_yaml
|
||||
from torch_geometric.graphgym.model_builder import create_model
|
||||
from torch_geometric.graphgym.train import GraphGymDataModule
|
||||
from torch_geometric.graphgym.utils.io import json_to_dict_list
|
||||
|
||||
from explaining_framework.utils.io import read_yaml
|
||||
|
||||
MODEL_STATE = "model_state"
|
||||
OPTIMIZER_STATE = "optimizer_state"
|
||||
SCHEDULER_STATE = "scheduler_state"
|
||||
@ -44,14 +45,15 @@ class LoadModelInfo(object):
|
||||
|
||||
def list_stats(self, path) -> list:
|
||||
info = []
|
||||
for path in glob.glob(
|
||||
for path_ in glob.glob(
|
||||
os.path.join(path, "[0-9]", self.wrt_metric, "stats.json")
|
||||
):
|
||||
stats = json_to_dict_list(path)
|
||||
stats = json_to_dict_list(path_)
|
||||
for stat in stats:
|
||||
xp_dir_path = os.path.dirname(os.path.dirname(os.path.dirname(path)))
|
||||
xp_dir_path = os.path.dirname(os.path.dirname(os.path.dirname(path_)))
|
||||
seed = int(os.path.basename(os.path.dirname(os.path.dirname(path_))))
|
||||
ckpt_dir_path = os.path.join(
|
||||
os.path.dirname(os.path.dirname(path)), "ckpt"
|
||||
os.path.dirname(os.path.dirname(path_)), "ckpt"
|
||||
)
|
||||
cfg_path = os.path.join(xp_dir_path, "config.yaml")
|
||||
epoch = stat["epoch"]
|
||||
@ -68,12 +70,16 @@ class LoadModelInfo(object):
|
||||
epoch=epoch, ckpt_dir_path=ckpt_dir_path
|
||||
),
|
||||
"cfg_path": cfg_path,
|
||||
"seed": seed,
|
||||
"epoch": epoch,
|
||||
"accuracy": accuracy,
|
||||
"loss": loss,
|
||||
"lr": lr,
|
||||
"params": params,
|
||||
"time_iter": time_iter,
|
||||
"which": self.which
|
||||
if self.which in ["best", "worst"]
|
||||
else None,
|
||||
}
|
||||
)
|
||||
return info
|
||||
@ -112,6 +118,22 @@ class LoadModelInfo(object):
|
||||
self.info = [item for item in stats if item["ckpt_path"] == self.which][0]
|
||||
return self.info
|
||||
|
||||
def get_model_signature(self):
|
||||
if self.info is None:
|
||||
self.set_info()
|
||||
|
||||
model_name = os.path.basename(self.info["xp_dir_name"])
|
||||
model_seed = self.info["seed"]
|
||||
epoch = os.path.basename(self.info["ckpt_path"])
|
||||
model_signature = "-".join(
|
||||
[
|
||||
f"{name}={val}"
|
||||
for name, val in zip(["name", "seed"], [model_name, model_seed])
|
||||
]
|
||||
+ [epoch]
|
||||
)
|
||||
return model_signature
|
||||
|
||||
def get_ckpt_path(self, epoch: int, ckpt_dir_path: str):
|
||||
paths = os.path.join(ckpt_dir_path, "*.ckpt")
|
||||
ckpts = []
|
||||
|
@ -4,6 +4,7 @@ from typing import Any
|
||||
from eixgnn.eixgnn import EiXGNN
|
||||
from scgnn.scgnn import SCGNN
|
||||
from torch_geometric.data import Batch, Data
|
||||
from torch_geometric.data.loader.dataloader import DataLoader
|
||||
from torch_geometric.explain import Explainer
|
||||
from torch_geometric.graphgym.config import cfg
|
||||
from torch_geometric.graphgym.loader import create_dataset
|
||||
@ -62,6 +63,8 @@ all_fidelity = [
|
||||
"fidelity_plus_prob",
|
||||
"fidelity_minus_prob",
|
||||
"infidelity_KL",
|
||||
"characterization",
|
||||
"characterization_prob",
|
||||
]
|
||||
all_accuracy = [
|
||||
"precision_score",
|
||||
@ -94,6 +97,7 @@ class ExplainingOutline(object):
|
||||
self.model_info = None
|
||||
self.metrics = None
|
||||
self.attacks = None
|
||||
self.model_signature = None
|
||||
|
||||
self.load_explaining_cfg()
|
||||
self.load_model_info()
|
||||
@ -112,6 +116,7 @@ class ExplainingOutline(object):
|
||||
which=self.explaining_cfg.model.ckpt,
|
||||
)
|
||||
self.model_info = info.set_info()
|
||||
self.model_signature = info.get_model_signature()
|
||||
|
||||
def load_cfg(self):
|
||||
cfg.set_new_allowed(True)
|
||||
@ -166,6 +171,7 @@ class ExplainingOutline(object):
|
||||
if isinstance(self.explaining_cfg.dataset.specific_items, int):
|
||||
ind = self.explaining_cfg.dataset.specific_items
|
||||
self.dataset = self.dataset[ind : ind + 1]
|
||||
self.dataset = DataLoader(dataset=dataset, shuffle=False, batch_size=1)
|
||||
|
||||
def load_explainer(self):
|
||||
self.load_explainer_cfg()
|
||||
@ -199,6 +205,10 @@ class ExplainingOutline(object):
|
||||
interest_map_norm=self.explainer_cfg.interest_map_norm,
|
||||
score_map_norm=self.explainer_cfg.score_map_norm,
|
||||
)
|
||||
elif name is None:
|
||||
explaining_algorithm = None
|
||||
else:
|
||||
raise ValueError(f"{name_} Metric is not supported yet")
|
||||
self.explaining_algorithm = explaining_algorithm
|
||||
|
||||
def load_metric(self):
|
||||
@ -228,6 +238,10 @@ class ExplainingOutline(object):
|
||||
raise ValueError(
|
||||
f"The metric {name} is not supported for dataset {self.explaining_cfg.dataset.name} yet, it requires groundtruth explanation"
|
||||
)
|
||||
elif name_ is None:
|
||||
self.metrics = []
|
||||
else:
|
||||
raise ValueError(f"{name_} Metric is not supported yet")
|
||||
|
||||
def load_attack(self):
|
||||
if self.cfg is None:
|
||||
@ -238,5 +252,9 @@ class ExplainingOutline(object):
|
||||
if name_ == "all":
|
||||
all_rob_metrics = [Attack(name) for name in all_robust]
|
||||
self.attacks = all_rob_metrics
|
||||
if name_ in all_robust:
|
||||
elif name_ in all_robust:
|
||||
self.attacks = [Attack(name_)]
|
||||
elif name_ is None:
|
||||
slef.attacks = []
|
||||
else:
|
||||
raise ValueError(f"{name_} is an Attack method that is not supported yet")
|
||||
|
@ -1,5 +1,6 @@
|
||||
import json
|
||||
import os
|
||||
|
||||
import yaml
|
||||
|
||||
|
||||
@ -24,6 +25,26 @@ def write_yaml(data: dict, path: str) -> None:
|
||||
with open(path, "w") as f:
|
||||
data = yaml.dump(data, f)
|
||||
|
||||
def is_exists(path:str)-> bool:
|
||||
return os.path.exists(path)
|
||||
|
||||
def is_exists(path: str) -> bool:
|
||||
return os.path.exists(path)
|
||||
|
||||
|
||||
def get_obj_config(obj):
|
||||
config = {k: getattr(obj, k) for k in dir(obj)}
|
||||
config = {
|
||||
k: v
|
||||
for k, v in config.items()
|
||||
if isinstance(v, (int, float, str, bool)) or v is None
|
||||
}
|
||||
return config
|
||||
|
||||
|
||||
def save_obj_config(obj, path) -> None:
|
||||
config = get_obj_config(obj)
|
||||
write_json(config, path)
|
||||
|
||||
|
||||
def obj_config_to_str(obj) -> str:
|
||||
config = get_obj_config(obj)
|
||||
return "-".join([f"{k}={v}" for k, v in config.items()])
|
||||
|
134
main.py
134
main.py
@ -3,6 +3,138 @@
|
||||
#
|
||||
|
||||
import os
|
||||
import time
|
||||
|
||||
from torch_geometric import seed_everything
|
||||
from torch_geometric.data.makedirs import makedirs
|
||||
from torch_geometric.explain import Explainer
|
||||
from torch_geometric.explain.config import ThresholdConfig
|
||||
from torch_geometric.graphgym.config import cfg
|
||||
from torch_geometric.graphgym.utils.device import auto_select_device
|
||||
|
||||
from explaining_framework.config.explaining_config import explaining_cfg
|
||||
from explaining_framework.utils.explaining.cmd_args import parse_args
|
||||
from explaining_framework.utils.explaining.outline import parse_args
|
||||
from explaining_framework.utils.explaining.outline import ExplainingOutline
|
||||
from explaining_framework.utils.io import (obj_config_to_str, read_json,
|
||||
write_json, write_yaml)
|
||||
from explaining_framework.utils.explanation.adjust import Adjust
|
||||
|
||||
# inference, time, force,
|
||||
|
||||
def get_pred(explanation,force=False):
|
||||
dict_ = explanation.to_dict()
|
||||
if dict_.get('pred') is None or dict_.get('pred_masked') or force:
|
||||
pred = explainer.get_prediction(explanation)
|
||||
pred_masked = explainer.get_masked_prediction(x=explanation.x,edge_index=explanation.edge_index,node_mask=explanation.node_mask,edge_mask=explanation.edge_mask)
|
||||
explanation.__setattr__('pred',pred)
|
||||
explanation.__setattr__('pred_masked',pred_masked)
|
||||
return explanation
|
||||
else:
|
||||
return explanation
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parse_args()
|
||||
outline = ExplainingOutline(args.explaining_cfg_file)
|
||||
auto_select_device()
|
||||
|
||||
# Load components
|
||||
dataset = outline.dataset.to(cfg.accelerator)
|
||||
model = outline.model.to(cfg.accelerator)
|
||||
model_info = outline.model_info
|
||||
metrics = outline.metrics
|
||||
explaining_algorithm = outline.explaining_algorithm
|
||||
attacks = outline.attacks
|
||||
explainer_cfg = outline.explainer_cfg
|
||||
model_signature = outline.model_signature
|
||||
|
||||
# Set seed
|
||||
seed_everything(explaining_cfg.seed)
|
||||
|
||||
# Global path
|
||||
global_path = os.path.join(explaining_cfg.out_dir, model_signature)
|
||||
makedirs(global_path)
|
||||
write_yaml(cfg, os.path.join(global_path, "config.yaml"))
|
||||
write_json(model_info, os.path.join(global_path, "info.json"))
|
||||
|
||||
global_path = os.path.join(global_path, explaining_cfg.explainer.name+'_'+obj_config_to_str(explaining_algorithm))
|
||||
makedirs(global_path)
|
||||
write_yaml(explaining_cfg, os.path.join(global_path, explaining_cfg.cfg_dest))
|
||||
write_yaml(explainer_cfg, os.path.join(global_path, "explainer_cfg.yaml"))
|
||||
|
||||
global_path = os.path.join(global_path, obj_config_to_str(explaining_algorithm))
|
||||
makedirs(global_path)
|
||||
explainer = Explainer(
|
||||
model=model,
|
||||
algorithm=explaining_algorithm,
|
||||
explainer_config=dict(
|
||||
explanation_type=explaining_cfg.explanation_type,
|
||||
node_mask_type="object",
|
||||
edge_mask_type="object",
|
||||
),
|
||||
model_config=dict(
|
||||
mode="regression",
|
||||
task_level=cfg.dataset.task,
|
||||
return_type=explaining_cfg.model_config.return_type,
|
||||
),
|
||||
)
|
||||
# Save explaining configuration
|
||||
for index, item in enumerate(dataset):
|
||||
save_raw_path = os.path.join(global_path, "raw")
|
||||
makedirs(save_raw_path)
|
||||
explanation_path = os.path.join(save_raw_path, f"{index}.json")
|
||||
|
||||
if is_exists(explanation_path):
|
||||
if explaining_cfg.explainer.force:
|
||||
explanation = explainer(
|
||||
x=item.x,
|
||||
edge_index=item.edge_index,
|
||||
index=item.y,
|
||||
target=item.y,
|
||||
)
|
||||
else:
|
||||
explanation = load_explanation(explanation_path)
|
||||
else:
|
||||
explanation = explainer(
|
||||
x=item.x,
|
||||
edge_index=item.edge_index,
|
||||
index=item.y,
|
||||
target=item.y,
|
||||
)
|
||||
explanation = get_pred(explanation,force=False)
|
||||
save_explanation(explanation,explanation_path)
|
||||
for apply_relu in [True,False]:
|
||||
for apply_absolute in [True,False]:
|
||||
adjust = Adjust(apply_relu=apply_relu,apply_absolute=apply_absolute)
|
||||
save_raw_path = os.path.join(global_path,f'adjust-{obj_config_to_str(adjust)}')
|
||||
makedirs(save_raw_path)
|
||||
explanation = adjust.forward(explanation)
|
||||
explanation_path = os.path.join(save_raw_path, f"{index}.json")
|
||||
explanation = get_pred(explanation,force=True)
|
||||
save_explanation(explanation,explanation_path)
|
||||
|
||||
for threshold_approach in ['hard','topk','topk_hard']:
|
||||
for threshold_value in explaining_cfg.threshold_config.value:
|
||||
|
||||
masking_path =os.path.join(save_raw_path,f'threshold={threshold_approach}-value={value}')
|
||||
exp_threshold_path = os.path.join(masking_path,f'{index}.json')
|
||||
|
||||
if is_exists(exp_threshold_path):
|
||||
explanation = load_explanation(exp_threshold_path)
|
||||
else:
|
||||
threshold_conf = {'threshold_type':threshold_approach,'value':threshold_value}
|
||||
explainer.threshold_config = ThresholdConfig.cast(threshold_conf)
|
||||
|
||||
expl = copy.copy(explanation)
|
||||
exp_threshold = explainer._post_process(expl)
|
||||
exp_threshold= get_pred(exp_threshold,force=True)
|
||||
|
||||
save_explanation(exp_threshold,exp_threshold_path)
|
||||
for metric in metrics:
|
||||
metric_path =os.path.join(masking_path,f'{obj_config_to_str(metric)}')
|
||||
if is_exists(os.path.join(metric_path,f'{index}.json')):
|
||||
continue
|
||||
else:
|
||||
out = metric.forward(exp_threshold)
|
||||
write_json({f'{metric.name}':out})
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user