Fix some bugs and add new features
This commit is contained in:
parent db04fbfaeb
commit 3372f81576
@@ -57,7 +57,7 @@ def set_cfg(explaining_cfg):
     explaining_cfg.dataset.name = "Cora"
 
-    explaining_cfg.dataset.item = None
+    explaining_cfg.dataset.item = []
 
     # ----------------------------------------------------------------------- #
     # Model options
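A note on this default change: an empty list (rather than None) lets the loader treat "no items configured" and "these specific items" with one isinstance check. A minimal sketch of how the new default appears to be consumed, mirroring the load_dataset hunk later in this commit (the dataset length here is a placeholder):

item = []          # the new default written by set_cfg
dataset_len = 10   # placeholder; the real code uses len(self.dataset)
if isinstance(item, list):
    indexes = list(range(dataset_len)) if len(item) == 0 else item
print(indexes)     # [] selects every graph; a non-empty list selects a subset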
@@ -1,150 +0,0 @@
-import glob
-import os
-import shutil
-
-from explaining_framework.utils.io import read_yaml, write_yaml
-from torch_geometric.data.makedirs import makedirs
-from torch_geometric.graphgym.loader import create_dataset
-from torch_geometric.graphgym.utils.io import string_to_python
-
-if "__main__" == __name__:
-    config_folder = os.path.abspath(
-        os.path.join(os.path.dirname(__name__), "../../", "configs")
-    )
-    makedirs(config_folder)
-    explaining_folder = os.path.join(config_folder, "explaining")
-    makedirs(explaining_folder)
-    explainer_folder = os.path.join(config_folder, "explaining")
-    makedirs(explainer_folder)
-
-    DATASET = [
-        "CIFAR10",
-        # "TRIANGLES",
-        # "COLORS-3",
-        # "REDDIT-BINARY",
-        # "REDDIT-MULTI-5K",
-        # "REDDIT-MULTI-12K",
-        # "COLLAB",
-        # "DBLP_v1",
-        # "COIL-DEL",
-        # "COIL-RAG",
-        # "Fingerprint",
-        # "Letter-high",
-        # "Letter-low",
-        # "Letter-med",
-        "MSRC_9",
-        # "MSRC_21",
-        "MSRC_21C",
-        # "DD",
-        # "ENZYMES",
-        "PROTEINS",
-        # "QM9",
-        # "MUTAG",
-        # "Mutagenicity",
-        # "AIDS",
-        # "PATTERN",
-        # "CLUSTER",
-        "MNIST",
-        "CIFAR10",
-        # "TSP",
-        # "CSL",
-        # "KarateClub",
-        # "CS",
-        # "Physics",
-        # "BBBP",
-        # "Tox21",
-        # "HIV",
-        # "PCBA",
-        # "MUV",
-        # "BACE",
-        # "SIDER",
-        # "ClinTox",
-        # "AIFB",
-        # "AM",
-        # "MUTAG",
-        # "BGS",
-        # "FAUST",
-        # "DynamicFAUST",
-        # "ShapeNet",
-        # "ModelNet10",
-        # "ModelNet40",
-        # "PascalVOC-SP",
-        # "COCO-SP",
-    ]
-    EXPLAINER = [
-        "CAM",
-        "GradCAM",
-        "GNN_LRP",
-        "GradExplainer",
-        "GuidedBackPropagation",
-        "IntegratedGradients",
-        # "PGExplainer",
-        "PGMExplainer",
-        "RandomExplainer",
-        # "SubgraphX",
-        # "GraphMASK",
-        "GNNExplainer",
-        "EIXGNN",
-        "SCGNN",
-    ]
-
-    for dataset_name in DATASET:
-        for model_kind in ["best", "worst"]:
-            for explainer_name in EXPLAINER:
-                explaining_cfg = {}
-                # explaining_cfg['adjust']['strategy']= 'rpns'
-                # explaining_cfg['attack']['name']= 'all'
-                explaining_cfg["cfg_dest"] = string_to_python(
-                    f"dataset={dataset_name}-model={model_kind}-explainer={explainer_name}.yaml"
-                )
-                # = f"dataset={dataset_name}-model={model_kind}=explainer={explainer_name}-chunk=[{chunk[0]},{chunk[-1]}]"
-
-                explaining_cfg["dataset"] = {}
-                explaining_cfg["dataset"]["name"] = string_to_python(dataset_name)
-                explaining_cfg["dataset"]["item"] = [3, 45, 78, 23]
-                # explaining_cfg['explainer']['cfg']= 'default'
-                explaining_cfg["explainer"] = {}
-                explaining_cfg["explainer"]["name"] = string_to_python(explainer_name)
-                explaining_cfg["explainer"]["force"] = True
-                explaining_cfg["explanation_type"] = string_to_python("phenomenon")
-                # explaining_cfg['metrics']['accuracy']['name']='all'
-                # explaining_cfg['metrics']['fidelity']['name']='all'
-                # explaining_cfg['metrics']['sparsity']['name']='all'
-                explaining_cfg["model"] = {}
-                explaining_cfg["model"]["ckpt"] = string_to_python(model_kind)
-                explaining_cfg["model"]["path"] = string_to_python(
-                    # "/media/data/SIC/araison/exps/pyg_fork/graphgym/results/graph_classif_base_grid_graph_classif_grid"
-                    "/home/SIC/araison/exps/pytorch_geometric/graphgym/results/"
-                    # "/media/data/SIC/araison/exps/pyg_fork/graphgym/results/graph_classif_base_grid_graph_classif_grid"
-                )
-                # explaining_cfg['out_dir']='./explanation'
-                # explaining_cfg['print']='both'
-                # explaining_cfg['threshold']['config']['type']='all'
-                # explaining_cfg['threshold']['value']['hard']=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
-                # explaining_cfg['threshold']['value']['topk']=[2, 3, 5, 10, 20, 30, 50]
-                PATH = os.path.join(
-                    explaining_folder + "/" + explaining_cfg["cfg_dest"],
-                )
-                write_yaml(explaining_cfg, PATH)
-                # if os.path.exists(PATH):
-                #     continue
-                # else:
-                #     write_yaml(explaining_cfg, PATH)
-    # configs = [
-    #     path for path in glob.glob(os.path.join(explaining_folder, "**", "*.yaml"))
-    # ]
-    # for path in configs:
-    #     data = read_yaml(path)
-    #     data["model"][
-    #         "path"
-    #     ] = "/media/data/SIC/araison/exps/pyg_fork/graphgym/results/graph_classif_base_grid_graph_classif_grid"
-    #     write_yaml(data, path)
-
-    # for index, config_chunk in enumerate(
-    #     chunkizing_list(configs, int(len(configs) / 5))
-    # ):
-    #     PATH_ = os.path.join(explaining_folder, f"gpu={index}")
-    #     makedirs(PATH_)
-    #     for path in config_chunk:
-    #         filename = os.path.basename(path)
-    #         shutil.copy2(path, os.path.join(PATH_, filename))
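The deleted file was a one-off generator that wrote one explaining config per (dataset, model, explainer) combination. A minimal self-contained sketch of the same idea, assuming PyYAML and abridged list contents (the real script used the framework's write_yaml, makedirs, and string_to_python helpers):

import itertools
import os

import yaml  # PyYAML; the original used explaining_framework.utils.io.write_yaml

DATASET = ["CIFAR10", "MNIST", "PROTEINS"]      # abridged from the script
EXPLAINER = ["CAM", "GradCAM", "GNNExplainer"]  # abridged from the script

out_dir = "configs/explaining"  # layout assumed from the script
os.makedirs(out_dir, exist_ok=True)
for ds, kind, expl in itertools.product(DATASET, ["best", "worst"], EXPLAINER):
    cfg = {
        "cfg_dest": f"dataset={ds}-model={kind}-explainer={expl}.yaml",
        "dataset": {"name": ds, "item": [3, 45, 78, 23]},
        "explainer": {"name": expl, "force": True},
        "explanation_type": "phenomenon",
        "model": {"ckpt": kind},
    }
    with open(os.path.join(out_dir, cfg["cfg_dest"]), "w") as f:
        yaml.safe_dump(cfg, f)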
@@ -6,6 +6,20 @@ import os
 from typing import Any
 
 from eixgnn.eixgnn import EiXGNN
+from scgnn.scgnn import SCGNN
+from torch_geometric import seed_everything
+from torch_geometric.data import Batch, Data
+from torch_geometric.data.makedirs import makedirs
+from torch_geometric.explain import Explainer
+from torch_geometric.explain.config import ThresholdConfig
+from torch_geometric.explain.explanation import Explanation
+from torch_geometric.graphgym.config import cfg
+from torch_geometric.graphgym.loader import create_dataset, create_dataset2
+from torch_geometric.graphgym.model_builder import cfg, create_model
+from torch_geometric.graphgym.utils.device import auto_select_device
+from torch_geometric.loader.dataloader import DataLoader
+from yacs.config import CfgNode as CN
+
 from explaining_framework.config.explainer_config.eixgnn_config import \
     eixgnn_cfg
 from explaining_framework.config.explainer_config.scgnn_config import scgnn_cfg
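Note the new create_dataset2 import: upstream torch_geometric.graphgym.loader exposes only create_dataset, so create_dataset2 presumably comes from the author's PyG fork. A defensive import sketch if running against stock PyG (an assumption, not part of the commit):

try:
    from torch_geometric.graphgym.loader import create_dataset2
except ImportError:  # stock PyG: fall back to the upstream loader
    from torch_geometric.graphgym.loader import create_dataset as create_dataset2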
@@ -31,19 +45,6 @@ from explaining_framework.utils.io import (dump_cfg, is_exists,
                                            obj_config_to_str, read_json,
                                            set_printing, write_json,
                                            write_yaml)
-from scgnn.scgnn import SCGNN
-from torch_geometric import seed_everything
-from torch_geometric.data import Batch, Data
-from torch_geometric.data.makedirs import makedirs
-from torch_geometric.explain import Explainer
-from torch_geometric.explain.config import ThresholdConfig
-from torch_geometric.explain.explanation import Explanation
-from torch_geometric.graphgym.config import cfg
-from torch_geometric.graphgym.loader import create_dataset
-from torch_geometric.graphgym.model_builder import cfg, create_model
-from torch_geometric.graphgym.utils.device import auto_select_device
-from torch_geometric.loader.dataloader import DataLoader
-from yacs.config import CfgNode as CN
-
 
 all__captum = [
     "LRP",
@@ -155,10 +156,9 @@ class ExplainingOutline(object):
         self.load_explainer_cfg()
         self.load_explaining_algorithm()
         self.load_explainer()
+        # self.load_dataset_to_dataloader()
         self.load_metric()
         self.load_attack()
-        self.load_dataset_to_dataloader()
-        self.load_indexes()
         self.load_adjust()
         self.load_threshold()
         self.load_graphstat()
@@ -171,38 +171,16 @@ class ExplainingOutline(object):
         device = self.cfg.accelerator
         self.model = self.model.to(device)
 
-    def get_data(self):
-        if self.dataset is None:
-            self.load_dataset()
-        try:
-            item = next(self.dataset)
-            device = self.cfg.accelerator
-            item = item.to(device)
-            return item
-        except StopIteration:
-            return None
+    # def get_data(self):
+    #     if self.dataset is None:
+    #         self.load_dataset()
+    #     try:
+    #         item = next(self.dataset)
+    #         device = self.cfg.accelerator
+    #         item = item.to(device)
+    #         return item
+    #     except StopIteration:
+    #         return None
 
-    def load_indexes(self):
-        item = self.explaining_cfg.dataset.item
-        if isinstance(item, (list, int)):
-            indexes = item
-        else:
-            indexes = list(range(len(self.dataset)))
-        self.indexes = iter(indexes)
-
-    def get_index(self):
-        if self.indexes is None:
-            self.load_indexes()
-        try:
-            item = next(self.indexes)
-            return item
-        except StopIteration:
-            return None
-
-    def get_item(self):
-        item = self.get_data()
-        index = self.get_index()
-        return item, index
-
     def load_model_info(self):
         info = LoadModelInfo(
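This hunk retires the pull-style accessors (get_data, get_index, get_item, which returned None on StopIteration) in favor of plain iteration over the dataset and its index list, as main.py now does. A toy sketch of the two styles (stand-in lists, not the framework's classes):

dataset = ["g0", "g1", "g2"]  # stands in for a PyG dataset / DataLoader
indexes = [0, 1, 2]

# old style: pull one item at a time, None signals exhaustion
it = iter(dataset)
def get_data():
    try:
        return next(it)
    except StopIteration:
        return None

# new style: zip drives item and index together, no sentinel needed
for item, index in zip(dataset, indexes):
    print(index, item)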
@@ -270,26 +248,19 @@ class ExplainingOutline(object):
             raise ValueError(
                 f"Expecting that the dataset to perform explanation on is the same as the model has trained on. Get {self.explaining_cfg.dataset.name} for explanation part, and {self.cfg.dataset.name} for the model."
             )
-        self.dataset = create_dataset()
+        self.dataset = create_dataset2()
 
         item = self.explaining_cfg.dataset.item
-        if isinstance(item, int):
-            self.dataset = self.dataset[item : item + 1]
-        elif isinstance(item, list):
-            self.dataset = self.dataset[item]
+        if isinstance(item, (list)):
+            if len(item) == 0:
+                self.indexes = list(range(len(self.dataset)))
+            else:
+                self.indexes = item
+        self.dataset = self.dataset[self.indexes]
 
-    def load_dataset_to_dataloader(self, to_iter=True):
+    def load_dataset_to_dataloader(self, to_iter=False):
         self.dataset = DataLoader(dataset=self.dataset, shuffle=False, batch_size=1)
-        if to_iter:
-            self.dataset = iter(self.dataset)
-
-    def reload_dataset(self):
-        self.load_dataset()
-        self.load_indexes()
-
-    def reload_dataloader(self):
-        self.load_dataset()
-        self.load_dataset_to_dataloader()
-        self.load_indexes()
 
     def load_explaining_algorithm(self):
         self.load_explainer_cfg()
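The selection logic now builds self.indexes first and slices the dataset once: an empty item list means "explain the whole dataset", a non-empty list picks those graphs. A minimal sketch with a plain list standing in for the PyG dataset (which supports the same list indexing):

item = []                                # the new default from set_cfg
dataset = list(range(100))               # placeholder for a PyG dataset
if isinstance(item, list):               # the diff assumes item is always a list now
    indexes = list(range(len(dataset))) if len(item) == 0 else item
dataset = [dataset[i] for i in indexes]  # dataset = self.dataset[self.indexes]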
main.py (21 changed lines)
@@ -26,12 +26,11 @@ from explaining_framework.utils.io import (dump_cfg, is_exists,
 if __name__ == "__main__":
     args = parse_args()
     outline = ExplainingOutline(args.explaining_cfg_file, args.gpu_id)
 
-    pbar = tqdm(total=len(outline.dataset) * len(outline.attacks))
-
-    for item, index in zip(outline.dataset, outline.indexes):
-        item = item.to(outline.cfg.accelerator)
-        for attack in outline.attacks:
+    for attack in outline.attacks:
+        for item, index in tqdm(
+            zip(outline.dataset, outline.indexes), total=len(outline.dataset)
+        ):
+            item = item.to(outline.cfg.accelerator)
             attack_path = os.path.join(
                 outline.out_dir, attack.__class__.__name__, obj_config_to_str(attack)
             )
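The single manually sized progress bar (tqdm(total=len(dataset) * len(attacks)) plus pbar.update(1) calls) is replaced by one tqdm bar per attack wrapped directly around the loop; total= must be passed explicitly because zip() has no length. A runnable sketch with stand-in data:

from tqdm import tqdm

dataset = ["g0", "g1", "g2"]  # stand-ins for graphs
indexes = [0, 1, 2]
attacks = ["attack_a", "attack_b"]

for attack in attacks:
    # zip has no __len__, so tqdm needs an explicit total to render a bar
    for item, index in tqdm(zip(dataset, indexes), total=len(dataset)):
        pass  # run `attack` on `item` here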
@@ -40,13 +39,12 @@ if __name__ == "__main__":
             data_attack = outline.get_attack(
                 attack=attack, item=item, path=data_attack_path
             )
-            if data_attack is None:
-                continue
-
-    outline.reload_dataloader()
-    for item, index in zip(outline.dataset, outline.indexes):
-        item = item.to(outline.cfg.accelerator)
-        for attack in outline.attacks:
+
+    for attack in outline.attacks:
+        for item, index in tqdm(
+            zip(outline.dataset, outline.indexes), total=len(outline.dataset)
+        ):
+            item = item.to(outline.cfg.accelerator)
             attack_path_ = os.path.join(
                 outline.explainer_path,
                 attack.__class__.__name__,
@@ -60,7 +58,6 @@ if __name__ == "__main__":
             if attack_data is None:
                 continue
             exp = outline.get_explanation(item=attack_data, path=data_attack_path_)
-            pbar.update(1)
             if exp is None:
                 continue
             else:
@@ -103,5 +100,3 @@ if __name__ == "__main__":
         )
     with open(os.path.join(outline.out_dir, "done"), "w") as f:
         f.write("")
-
-    pbar.close()
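The run still ends by touching an empty "done" file in the output directory, a common sentinel so a restarted job can detect a finished run. A self-contained sketch of the pattern (out_dir is a placeholder):

import os

out_dir = "results/example_run"  # placeholder path
os.makedirs(out_dir, exist_ok=True)
with open(os.path.join(out_dir, "done"), "w") as f:
    f.write("")
# a later run can skip work if the sentinel exists:
print(os.path.exists(os.path.join(out_dir, "done")))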