Adding new features and first draft of main.py
This commit is contained in:
parent
a00e73d4f0
commit
074ff25c83
5 changed files with 203 additions and 17 deletions
|
@ -49,11 +49,4 @@ class Metric(ABC):
|
||||||
|
|
||||||
return out
|
return out
|
||||||
|
|
||||||
def save_config(self, path) -> None:
|
|
||||||
config = {k: getattr(self, k) for k in dir(self)}
|
|
||||||
config = {
|
|
||||||
k: v
|
|
||||||
for k, v in config.items()
|
|
||||||
if isinstance(v, (int, float, str, bool)) or v is None
|
|
||||||
}
|
|
||||||
write_json(config, path)
|
|
||||||
|
|
|
@ -7,11 +7,12 @@ import logging
|
||||||
import os
|
import os
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from explaining_framework.utils.io import read_yaml
|
|
||||||
from torch_geometric.graphgym.model_builder import create_model
|
from torch_geometric.graphgym.model_builder import create_model
|
||||||
from torch_geometric.graphgym.train import GraphGymDataModule
|
from torch_geometric.graphgym.train import GraphGymDataModule
|
||||||
from torch_geometric.graphgym.utils.io import json_to_dict_list
|
from torch_geometric.graphgym.utils.io import json_to_dict_list
|
||||||
|
|
||||||
|
from explaining_framework.utils.io import read_yaml
|
||||||
|
|
||||||
MODEL_STATE = "model_state"
|
MODEL_STATE = "model_state"
|
||||||
OPTIMIZER_STATE = "optimizer_state"
|
OPTIMIZER_STATE = "optimizer_state"
|
||||||
SCHEDULER_STATE = "scheduler_state"
|
SCHEDULER_STATE = "scheduler_state"
|
||||||
|
@ -44,14 +45,15 @@ class LoadModelInfo(object):
|
||||||
|
|
||||||
def list_stats(self, path) -> list:
|
def list_stats(self, path) -> list:
|
||||||
info = []
|
info = []
|
||||||
for path in glob.glob(
|
for path_ in glob.glob(
|
||||||
os.path.join(path, "[0-9]", self.wrt_metric, "stats.json")
|
os.path.join(path, "[0-9]", self.wrt_metric, "stats.json")
|
||||||
):
|
):
|
||||||
stats = json_to_dict_list(path)
|
stats = json_to_dict_list(path_)
|
||||||
for stat in stats:
|
for stat in stats:
|
||||||
xp_dir_path = os.path.dirname(os.path.dirname(os.path.dirname(path)))
|
xp_dir_path = os.path.dirname(os.path.dirname(os.path.dirname(path_)))
|
||||||
|
seed = int(os.path.basename(os.path.dirname(os.path.dirname(path_))))
|
||||||
ckpt_dir_path = os.path.join(
|
ckpt_dir_path = os.path.join(
|
||||||
os.path.dirname(os.path.dirname(path)), "ckpt"
|
os.path.dirname(os.path.dirname(path_)), "ckpt"
|
||||||
)
|
)
|
||||||
cfg_path = os.path.join(xp_dir_path, "config.yaml")
|
cfg_path = os.path.join(xp_dir_path, "config.yaml")
|
||||||
epoch = stat["epoch"]
|
epoch = stat["epoch"]
|
||||||
|
@ -68,12 +70,16 @@ class LoadModelInfo(object):
|
||||||
epoch=epoch, ckpt_dir_path=ckpt_dir_path
|
epoch=epoch, ckpt_dir_path=ckpt_dir_path
|
||||||
),
|
),
|
||||||
"cfg_path": cfg_path,
|
"cfg_path": cfg_path,
|
||||||
|
"seed": seed,
|
||||||
"epoch": epoch,
|
"epoch": epoch,
|
||||||
"accuracy": accuracy,
|
"accuracy": accuracy,
|
||||||
"loss": loss,
|
"loss": loss,
|
||||||
"lr": lr,
|
"lr": lr,
|
||||||
"params": params,
|
"params": params,
|
||||||
"time_iter": time_iter,
|
"time_iter": time_iter,
|
||||||
|
"which": self.which
|
||||||
|
if self.which in ["best", "worst"]
|
||||||
|
else None,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
return info
|
return info
|
||||||
|
@ -112,6 +118,22 @@ class LoadModelInfo(object):
|
||||||
self.info = [item for item in stats if item["ckpt_path"] == self.which][0]
|
self.info = [item for item in stats if item["ckpt_path"] == self.which][0]
|
||||||
return self.info
|
return self.info
|
||||||
|
|
||||||
|
def get_model_signature(self):
|
||||||
|
if self.info is None:
|
||||||
|
self.set_info()
|
||||||
|
|
||||||
|
model_name = os.path.basename(self.info["xp_dir_name"])
|
||||||
|
model_seed = self.info["seed"]
|
||||||
|
epoch = os.path.basename(self.info["ckpt_path"])
|
||||||
|
model_signature = "-".join(
|
||||||
|
[
|
||||||
|
f"{name}={val}"
|
||||||
|
for name, val in zip(["name", "seed"], [model_name, model_seed])
|
||||||
|
]
|
||||||
|
+ [epoch]
|
||||||
|
)
|
||||||
|
return model_signature
|
||||||
|
|
||||||
def get_ckpt_path(self, epoch: int, ckpt_dir_path: str):
|
def get_ckpt_path(self, epoch: int, ckpt_dir_path: str):
|
||||||
paths = os.path.join(ckpt_dir_path, "*.ckpt")
|
paths = os.path.join(ckpt_dir_path, "*.ckpt")
|
||||||
ckpts = []
|
ckpts = []
|
||||||
|
|
|
@ -4,6 +4,7 @@ from typing import Any
|
||||||
from eixgnn.eixgnn import EiXGNN
|
from eixgnn.eixgnn import EiXGNN
|
||||||
from scgnn.scgnn import SCGNN
|
from scgnn.scgnn import SCGNN
|
||||||
from torch_geometric.data import Batch, Data
|
from torch_geometric.data import Batch, Data
|
||||||
|
from torch_geometric.data.loader.dataloader import DataLoader
|
||||||
from torch_geometric.explain import Explainer
|
from torch_geometric.explain import Explainer
|
||||||
from torch_geometric.graphgym.config import cfg
|
from torch_geometric.graphgym.config import cfg
|
||||||
from torch_geometric.graphgym.loader import create_dataset
|
from torch_geometric.graphgym.loader import create_dataset
|
||||||
|
@ -62,6 +63,8 @@ all_fidelity = [
|
||||||
"fidelity_plus_prob",
|
"fidelity_plus_prob",
|
||||||
"fidelity_minus_prob",
|
"fidelity_minus_prob",
|
||||||
"infidelity_KL",
|
"infidelity_KL",
|
||||||
|
"characterization",
|
||||||
|
"characterization_prob",
|
||||||
]
|
]
|
||||||
all_accuracy = [
|
all_accuracy = [
|
||||||
"precision_score",
|
"precision_score",
|
||||||
|
@ -94,6 +97,7 @@ class ExplainingOutline(object):
|
||||||
self.model_info = None
|
self.model_info = None
|
||||||
self.metrics = None
|
self.metrics = None
|
||||||
self.attacks = None
|
self.attacks = None
|
||||||
|
self.model_signature = None
|
||||||
|
|
||||||
self.load_explaining_cfg()
|
self.load_explaining_cfg()
|
||||||
self.load_model_info()
|
self.load_model_info()
|
||||||
|
@ -112,6 +116,7 @@ class ExplainingOutline(object):
|
||||||
which=self.explaining_cfg.model.ckpt,
|
which=self.explaining_cfg.model.ckpt,
|
||||||
)
|
)
|
||||||
self.model_info = info.set_info()
|
self.model_info = info.set_info()
|
||||||
|
self.model_signature = info.get_model_signature()
|
||||||
|
|
||||||
def load_cfg(self):
|
def load_cfg(self):
|
||||||
cfg.set_new_allowed(True)
|
cfg.set_new_allowed(True)
|
||||||
|
@ -166,6 +171,7 @@ class ExplainingOutline(object):
|
||||||
if isinstance(self.explaining_cfg.dataset.specific_items, int):
|
if isinstance(self.explaining_cfg.dataset.specific_items, int):
|
||||||
ind = self.explaining_cfg.dataset.specific_items
|
ind = self.explaining_cfg.dataset.specific_items
|
||||||
self.dataset = self.dataset[ind : ind + 1]
|
self.dataset = self.dataset[ind : ind + 1]
|
||||||
|
self.dataset = DataLoader(dataset=dataset, shuffle=False, batch_size=1)
|
||||||
|
|
||||||
def load_explainer(self):
|
def load_explainer(self):
|
||||||
self.load_explainer_cfg()
|
self.load_explainer_cfg()
|
||||||
|
@ -199,6 +205,10 @@ class ExplainingOutline(object):
|
||||||
interest_map_norm=self.explainer_cfg.interest_map_norm,
|
interest_map_norm=self.explainer_cfg.interest_map_norm,
|
||||||
score_map_norm=self.explainer_cfg.score_map_norm,
|
score_map_norm=self.explainer_cfg.score_map_norm,
|
||||||
)
|
)
|
||||||
|
elif name is None:
|
||||||
|
explaining_algorithm = None
|
||||||
|
else:
|
||||||
|
raise ValueError(f"{name_} Metric is not supported yet")
|
||||||
self.explaining_algorithm = explaining_algorithm
|
self.explaining_algorithm = explaining_algorithm
|
||||||
|
|
||||||
def load_metric(self):
|
def load_metric(self):
|
||||||
|
@ -228,6 +238,10 @@ class ExplainingOutline(object):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"The metric {name} is not supported for dataset {self.explaining_cfg.dataset.name} yet, it requires groundtruth explanation"
|
f"The metric {name} is not supported for dataset {self.explaining_cfg.dataset.name} yet, it requires groundtruth explanation"
|
||||||
)
|
)
|
||||||
|
elif name_ is None:
|
||||||
|
self.metrics = []
|
||||||
|
else:
|
||||||
|
raise ValueError(f"{name_} Metric is not supported yet")
|
||||||
|
|
||||||
def load_attack(self):
|
def load_attack(self):
|
||||||
if self.cfg is None:
|
if self.cfg is None:
|
||||||
|
@ -238,5 +252,9 @@ class ExplainingOutline(object):
|
||||||
if name_ == "all":
|
if name_ == "all":
|
||||||
all_rob_metrics = [Attack(name) for name in all_robust]
|
all_rob_metrics = [Attack(name) for name in all_robust]
|
||||||
self.attacks = all_rob_metrics
|
self.attacks = all_rob_metrics
|
||||||
if name_ in all_robust:
|
elif name_ in all_robust:
|
||||||
self.attacks = [Attack(name_)]
|
self.attacks = [Attack(name_)]
|
||||||
|
elif name_ is None:
|
||||||
|
slef.attacks = []
|
||||||
|
else:
|
||||||
|
raise ValueError(f"{name_} is an Attack method that is not supported yet")
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
|
|
||||||
|
@ -24,6 +25,26 @@ def write_yaml(data: dict, path: str) -> None:
|
||||||
with open(path, "w") as f:
|
with open(path, "w") as f:
|
||||||
data = yaml.dump(data, f)
|
data = yaml.dump(data, f)
|
||||||
|
|
||||||
def is_exists(path:str)-> bool:
|
|
||||||
return os.path.exists(path)
|
|
||||||
|
|
||||||
|
def is_exists(path: str) -> bool:
|
||||||
|
return os.path.exists(path)
|
||||||
|
|
||||||
|
|
||||||
|
def get_obj_config(obj):
|
||||||
|
config = {k: getattr(obj, k) for k in dir(obj)}
|
||||||
|
config = {
|
||||||
|
k: v
|
||||||
|
for k, v in config.items()
|
||||||
|
if isinstance(v, (int, float, str, bool)) or v is None
|
||||||
|
}
|
||||||
|
return config
|
||||||
|
|
||||||
|
|
||||||
|
def save_obj_config(obj, path) -> None:
|
||||||
|
config = get_obj_config(obj)
|
||||||
|
write_json(config, path)
|
||||||
|
|
||||||
|
|
||||||
|
def obj_config_to_str(obj) -> str:
|
||||||
|
config = get_obj_config(obj)
|
||||||
|
return "-".join([f"{k}={v}" for k, v in config.items()])
|
||||||
|
|
134
main.py
134
main.py
|
@ -3,6 +3,138 @@
|
||||||
#
|
#
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
import time
|
||||||
|
|
||||||
|
from torch_geometric import seed_everything
|
||||||
|
from torch_geometric.data.makedirs import makedirs
|
||||||
|
from torch_geometric.explain import Explainer
|
||||||
|
from torch_geometric.explain.config import ThresholdConfig
|
||||||
|
from torch_geometric.graphgym.config import cfg
|
||||||
|
from torch_geometric.graphgym.utils.device import auto_select_device
|
||||||
|
|
||||||
from explaining_framework.config.explaining_config import explaining_cfg
|
from explaining_framework.config.explaining_config import explaining_cfg
|
||||||
from explaining_framework.utils.explaining.cmd_args import parse_args
|
from explaining_framework.utils.explaining.cmd_args import parse_args
|
||||||
from explaining_framework.utils.explaining.outline import parse_args
|
from explaining_framework.utils.explaining.outline import ExplainingOutline
|
||||||
|
from explaining_framework.utils.io import (obj_config_to_str, read_json,
|
||||||
|
write_json, write_yaml)
|
||||||
|
from explaining_framework.utils.explanation.adjust import Adjust
|
||||||
|
|
||||||
|
# inference, time, force,
|
||||||
|
|
||||||
|
def get_pred(explanation,force=False):
|
||||||
|
dict_ = explanation.to_dict()
|
||||||
|
if dict_.get('pred') is None or dict_.get('pred_masked') or force:
|
||||||
|
pred = explainer.get_prediction(explanation)
|
||||||
|
pred_masked = explainer.get_masked_prediction(x=explanation.x,edge_index=explanation.edge_index,node_mask=explanation.node_mask,edge_mask=explanation.edge_mask)
|
||||||
|
explanation.__setattr__('pred',pred)
|
||||||
|
explanation.__setattr__('pred_masked',pred_masked)
|
||||||
|
return explanation
|
||||||
|
else:
|
||||||
|
return explanation
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
args = parse_args()
|
||||||
|
outline = ExplainingOutline(args.explaining_cfg_file)
|
||||||
|
auto_select_device()
|
||||||
|
|
||||||
|
# Load components
|
||||||
|
dataset = outline.dataset.to(cfg.accelerator)
|
||||||
|
model = outline.model.to(cfg.accelerator)
|
||||||
|
model_info = outline.model_info
|
||||||
|
metrics = outline.metrics
|
||||||
|
explaining_algorithm = outline.explaining_algorithm
|
||||||
|
attacks = outline.attacks
|
||||||
|
explainer_cfg = outline.explainer_cfg
|
||||||
|
model_signature = outline.model_signature
|
||||||
|
|
||||||
|
# Set seed
|
||||||
|
seed_everything(explaining_cfg.seed)
|
||||||
|
|
||||||
|
# Global path
|
||||||
|
global_path = os.path.join(explaining_cfg.out_dir, model_signature)
|
||||||
|
makedirs(global_path)
|
||||||
|
write_yaml(cfg, os.path.join(global_path, "config.yaml"))
|
||||||
|
write_json(model_info, os.path.join(global_path, "info.json"))
|
||||||
|
|
||||||
|
global_path = os.path.join(global_path, explaining_cfg.explainer.name+'_'+obj_config_to_str(explaining_algorithm))
|
||||||
|
makedirs(global_path)
|
||||||
|
write_yaml(explaining_cfg, os.path.join(global_path, explaining_cfg.cfg_dest))
|
||||||
|
write_yaml(explainer_cfg, os.path.join(global_path, "explainer_cfg.yaml"))
|
||||||
|
|
||||||
|
global_path = os.path.join(global_path, obj_config_to_str(explaining_algorithm))
|
||||||
|
makedirs(global_path)
|
||||||
|
explainer = Explainer(
|
||||||
|
model=model,
|
||||||
|
algorithm=explaining_algorithm,
|
||||||
|
explainer_config=dict(
|
||||||
|
explanation_type=explaining_cfg.explanation_type,
|
||||||
|
node_mask_type="object",
|
||||||
|
edge_mask_type="object",
|
||||||
|
),
|
||||||
|
model_config=dict(
|
||||||
|
mode="regression",
|
||||||
|
task_level=cfg.dataset.task,
|
||||||
|
return_type=explaining_cfg.model_config.return_type,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
# Save explaining configuration
|
||||||
|
for index, item in enumerate(dataset):
|
||||||
|
save_raw_path = os.path.join(global_path, "raw")
|
||||||
|
makedirs(save_raw_path)
|
||||||
|
explanation_path = os.path.join(save_raw_path, f"{index}.json")
|
||||||
|
|
||||||
|
if is_exists(explanation_path):
|
||||||
|
if explaining_cfg.explainer.force:
|
||||||
|
explanation = explainer(
|
||||||
|
x=item.x,
|
||||||
|
edge_index=item.edge_index,
|
||||||
|
index=item.y,
|
||||||
|
target=item.y,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
explanation = load_explanation(explanation_path)
|
||||||
|
else:
|
||||||
|
explanation = explainer(
|
||||||
|
x=item.x,
|
||||||
|
edge_index=item.edge_index,
|
||||||
|
index=item.y,
|
||||||
|
target=item.y,
|
||||||
|
)
|
||||||
|
explanation = get_pred(explanation,force=False)
|
||||||
|
save_explanation(explanation,explanation_path)
|
||||||
|
for apply_relu in [True,False]:
|
||||||
|
for apply_absolute in [True,False]:
|
||||||
|
adjust = Adjust(apply_relu=apply_relu,apply_absolute=apply_absolute)
|
||||||
|
save_raw_path = os.path.join(global_path,f'adjust-{obj_config_to_str(adjust)}')
|
||||||
|
makedirs(save_raw_path)
|
||||||
|
explanation = adjust.forward(explanation)
|
||||||
|
explanation_path = os.path.join(save_raw_path, f"{index}.json")
|
||||||
|
explanation = get_pred(explanation,force=True)
|
||||||
|
save_explanation(explanation,explanation_path)
|
||||||
|
|
||||||
|
for threshold_approach in ['hard','topk','topk_hard']:
|
||||||
|
for threshold_value in explaining_cfg.threshold_config.value:
|
||||||
|
|
||||||
|
masking_path =os.path.join(save_raw_path,f'threshold={threshold_approach}-value={value}')
|
||||||
|
exp_threshold_path = os.path.join(masking_path,f'{index}.json')
|
||||||
|
|
||||||
|
if is_exists(exp_threshold_path):
|
||||||
|
explanation = load_explanation(exp_threshold_path)
|
||||||
|
else:
|
||||||
|
threshold_conf = {'threshold_type':threshold_approach,'value':threshold_value}
|
||||||
|
explainer.threshold_config = ThresholdConfig.cast(threshold_conf)
|
||||||
|
|
||||||
|
expl = copy.copy(explanation)
|
||||||
|
exp_threshold = explainer._post_process(expl)
|
||||||
|
exp_threshold= get_pred(exp_threshold,force=True)
|
||||||
|
|
||||||
|
save_explanation(exp_threshold,exp_threshold_path)
|
||||||
|
for metric in metrics:
|
||||||
|
metric_path =os.path.join(masking_path,f'{obj_config_to_str(metric)}')
|
||||||
|
if is_exists(os.path.join(metric_path,f'{index}.json')):
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
out = metric.forward(exp_threshold)
|
||||||
|
write_json({f'{metric.name}':out})
|
||||||
|
|
||||||
|
|
Loading…
Add table
Reference in a new issue