diff --git a/explaining_framework/config/explaining_config.py b/explaining_framework/config/explaining_config.py index 9641c25..cd68c42 100644 --- a/explaining_framework/config/explaining_config.py +++ b/explaining_framework/config/explaining_config.py @@ -57,7 +57,7 @@ def set_cfg(explaining_cfg): explaining_cfg.dataset.name = "Cora" - explaining_cfg.dataset.item = None + explaining_cfg.dataset.item = [] # ----------------------------------------------------------------------- # # Model options diff --git a/explaining_framework/utils/config_gen.py b/explaining_framework/utils/config_gen.py deleted file mode 100644 index fc5367d..0000000 --- a/explaining_framework/utils/config_gen.py +++ /dev/null @@ -1,150 +0,0 @@ -import glob -import os -import shutil - -from explaining_framework.utils.io import read_yaml, write_yaml -from torch_geometric.data.makedirs import makedirs -from torch_geometric.graphgym.loader import create_dataset -from torch_geometric.graphgym.utils.io import string_to_python - -if "__main__" == __name__: - config_folder = os.path.abspath( - os.path.join(os.path.dirname(__name__), "../../", "configs") - ) - makedirs(config_folder) - explaining_folder = os.path.join(config_folder, "explaining") - makedirs(explaining_folder) - explainer_folder = os.path.join(config_folder, "explaining") - makedirs(explainer_folder) - - DATASET = [ - "CIFAR10", - # "TRIANGLES", - # "COLORS-3", - # "REDDIT-BINARY", - # "REDDIT-MULTI-5K", - # "REDDIT-MULTI-12K", - # "COLLAB", - # "DBLP_v1", - # "COIL-DEL", - # "COIL-RAG", - # "Fingerprint", - # "Letter-high", - # "Letter-low", - # "Letter-med", - "MSRC_9", - # "MSRC_21", - "MSRC_21C", - # "DD", - # "ENZYMES", - "PROTEINS", - # "QM9", - # "MUTAG", - # "Mutagenicity", - # "AIDS", - # "PATTERN", - # "CLUSTER", - "MNIST", - "CIFAR10", - # "TSP", - # "CSL", - # "KarateClub", - # "CS", - # "Physics", - # "BBBP", - # "Tox21", - # "HIV", - # "PCBA", - # "MUV", - # "BACE", - # "SIDER", - # "ClinTox", - # "AIFB", - # "AM", - # "MUTAG", - # "BGS", - # "FAUST", - # "DynamicFAUST", - # "ShapeNet", - # "ModelNet10", - # "ModelNet40", - # "PascalVOC-SP", - # "COCO-SP", - ] - EXPLAINER = [ - "CAM", - "GradCAM", - "GNN_LRP", - "GradExplainer", - "GuidedBackPropagation", - "IntegratedGradients", - # "PGExplainer", - "PGMExplainer", - "RandomExplainer", - # "SubgraphX", - # "GraphMASK", - "GNNExplainer", - "EIXGNN", - "SCGNN", - ] - - for dataset_name in DATASET: - for model_kind in ["best", "worst"]: - for explainer_name in EXPLAINER: - explaining_cfg = {} - # explaining_cfg['adjust']['strategy']= 'rpns' - # explaining_cfg['attack']['name']= 'all' - explaining_cfg["cfg_dest"] = string_to_python( - f"dataset={dataset_name}-model={model_kind}-explainer={explainer_name}.yaml" - ) - # = f"dataset={dataset_name}-model={model_kind}=explainer={explainer_name}-chunk=[{chunk[0]},{chunk[-1]}]" - - explaining_cfg["dataset"] = {} - explaining_cfg["dataset"]["name"] = string_to_python(dataset_name) - explaining_cfg["dataset"]["item"] = [3, 45, 78, 23] - # explaining_cfg['explainer']['cfg']= 'default' - explaining_cfg["explainer"] = {} - explaining_cfg["explainer"]["name"] = string_to_python(explainer_name) - explaining_cfg["explainer"]["force"] = True - explaining_cfg["explanation_type"] = string_to_python("phenomenon") - # explaining_cfg['metrics']['accuracy']['name']='all' - # explaining_cfg['metrics']['fidelity']['name']='all' - # explaining_cfg['metrics']['sparsity']['name']='all' - explaining_cfg["model"] = {} - explaining_cfg["model"]["ckpt"] = string_to_python(model_kind) - explaining_cfg["model"]["path"] = string_to_python( - # "/media/data/SIC/araison/exps/pyg_fork/graphgym/results/graph_classif_base_grid_graph_classif_grid" - "/home/SIC/araison/exps/pytorch_geometric/graphgym/results/" - # "/media/data/SIC/araison/exps/pyg_fork/graphgym/results/graph_classif_base_grid_graph_classif_grid" - ) - # explaining_cfg['out_dir']='./explanation' - # explaining_cfg['print']='both' - # explaining_cfg['threshold']['config']['type']='all' - # explaining_cfg['threshold']['value']['hard']=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9] - # explaining_cfg['threshold']['value']['topk']=[2, 3, 5, 10, 20, 30, 50] - PATH = os.path.join( - explaining_folder + "/" + explaining_cfg["cfg_dest"], - ) - write_yaml(explaining_cfg, PATH) - # if os.path.exists(PATH): - # continue - # else: - # write_yaml(explaining_cfg, PATH) - # configs = [ - # path for path in glob.glob(os.path.join(explaining_folder, "**", "*.yaml")) - # ] - # for path in configs: - # data = read_yaml(path) - # data["model"][ - # "path" - # ] = "/media/data/SIC/araison/exps/pyg_fork/graphgym/results/graph_classif_base_grid_graph_classif_grid" - # write_yaml(data, path) - - # for index, config_chunk in enumerate( - # chunkizing_list(configs, int(len(configs) / 5)) - # ): - # PATH_ = os.path.join(explaining_folder, f"gpu={index}") - # makedirs(PATH_) - # for path in config_chunk: - # filename = os.path.basename(path) - # shutil.copy2(path, os.path.join(PATH_, filename)) diff --git a/explaining_framework/utils/explaining/outline.py b/explaining_framework/utils/explaining/outline.py index d1ba6a3..ea364d3 100644 --- a/explaining_framework/utils/explaining/outline.py +++ b/explaining_framework/utils/explaining/outline.py @@ -6,6 +6,20 @@ import os from typing import Any from eixgnn.eixgnn import EiXGNN +from scgnn.scgnn import SCGNN +from torch_geometric import seed_everything +from torch_geometric.data import Batch, Data +from torch_geometric.data.makedirs import makedirs +from torch_geometric.explain import Explainer +from torch_geometric.explain.config import ThresholdConfig +from torch_geometric.explain.explanation import Explanation +from torch_geometric.graphgym.config import cfg +from torch_geometric.graphgym.loader import create_dataset, create_dataset2 +from torch_geometric.graphgym.model_builder import cfg, create_model +from torch_geometric.graphgym.utils.device import auto_select_device +from torch_geometric.loader.dataloader import DataLoader +from yacs.config import CfgNode as CN + from explaining_framework.config.explainer_config.eixgnn_config import \ eixgnn_cfg from explaining_framework.config.explainer_config.scgnn_config import scgnn_cfg @@ -31,19 +45,6 @@ from explaining_framework.utils.io import (dump_cfg, is_exists, obj_config_to_str, read_json, set_printing, write_json, write_yaml) -from scgnn.scgnn import SCGNN -from torch_geometric import seed_everything -from torch_geometric.data import Batch, Data -from torch_geometric.data.makedirs import makedirs -from torch_geometric.explain import Explainer -from torch_geometric.explain.config import ThresholdConfig -from torch_geometric.explain.explanation import Explanation -from torch_geometric.graphgym.config import cfg -from torch_geometric.graphgym.loader import create_dataset -from torch_geometric.graphgym.model_builder import cfg, create_model -from torch_geometric.graphgym.utils.device import auto_select_device -from torch_geometric.loader.dataloader import DataLoader -from yacs.config import CfgNode as CN all__captum = [ "LRP", @@ -155,10 +156,9 @@ class ExplainingOutline(object): self.load_explainer_cfg() self.load_explaining_algorithm() self.load_explainer() + # self.load_dataset_to_dataloader() self.load_metric() self.load_attack() - self.load_dataset_to_dataloader() - self.load_indexes() self.load_adjust() self.load_threshold() self.load_graphstat() @@ -171,38 +171,16 @@ class ExplainingOutline(object): device = self.cfg.accelerator self.model = self.model.to(device) - def get_data(self): - if self.dataset is None: - self.load_dataset() - try: - item = next(self.dataset) - device = self.cfg.accelerator - item = item.to(device) - return item - except StopIteration: - return None - - def load_indexes(self): - item = self.explaining_cfg.dataset.item - if isinstance(item, (list, int)): - indexes = item - else: - indexes = list(range(len(self.dataset))) - self.indexes = iter(indexes) - - def get_index(self): - if self.indexes is None: - self.load_indexes() - try: - item = next(self.indexes) - return item - except StopIteration: - return None - - def get_item(self): - item = self.get_data() - index = self.get_index() - return item, index + # def get_data(self): + # if self.dataset is None: + # self.load_dataset() + # try: + # item = next(self.dataset) + # device = self.cfg.accelerator + # item = item.to(device) + # return item + # except StopIteration: + # return None def load_model_info(self): info = LoadModelInfo( @@ -270,26 +248,19 @@ class ExplainingOutline(object): raise ValueError( f"Expecting that the dataset to perform explanation on is the same as the model has trained on. Get {self.explaining_cfg.dataset.name} for explanation part, and {self.cfg.dataset.name} for the model." ) - self.dataset = create_dataset() + self.dataset = create_dataset2() + item = self.explaining_cfg.dataset.item - if isinstance(item, int): - self.dataset = self.dataset[item : item + 1] - elif isinstance(item, list): - self.dataset = self.dataset[item] + if isinstance(item, (list)): + if len(item) == 0: + self.indexes = list(range(len(self.dataset))) + else: + self.indexes = item - def load_dataset_to_dataloader(self, to_iter=True): + self.dataset = self.dataset[self.indexes] + + def load_dataset_to_dataloader(self, to_iter=False): self.dataset = DataLoader(dataset=self.dataset, shuffle=False, batch_size=1) - if to_iter: - self.dataset = iter(self.dataset) - - def reload_dataset(self): - self.load_dataset() - self.load_indexes() - - def reload_dataloader(self): - self.load_dataset() - self.load_dataset_to_dataloader() - self.load_indexes() def load_explaining_algorithm(self): self.load_explainer_cfg() diff --git a/main.py b/main.py index 00957d1..33c6a56 100644 --- a/main.py +++ b/main.py @@ -26,12 +26,11 @@ from explaining_framework.utils.io import (dump_cfg, is_exists, if __name__ == "__main__": args = parse_args() outline = ExplainingOutline(args.explaining_cfg_file, args.gpu_id) - - pbar = tqdm(total=len(outline.dataset) * len(outline.attacks)) - - for item, index in zip(outline.dataset, outline.indexes): - item = item.to(outline.cfg.accelerator) - for attack in outline.attacks: + for attack in outline.attacks: + for item, index in tqdm( + zip(outline.dataset, outline.indexes), total=len(outline.dataset) + ): + item = item.to(outline.cfg.accelerator) attack_path = os.path.join( outline.out_dir, attack.__class__.__name__, obj_config_to_str(attack) ) @@ -40,13 +39,12 @@ if __name__ == "__main__": data_attack = outline.get_attack( attack=attack, item=item, path=data_attack_path ) - if data_attack is None: - continue - outline.reload_dataloader() - for item, index in zip(outline.dataset, outline.indexes): - item = item.to(outline.cfg.accelerator) - for attack in outline.attacks: + for attack in outline.attacks: + for item, index in tqdm( + zip(outline.dataset, outline.indexes), total=len(outline.dataset) + ): + item = item.to(outline.cfg.accelerator) attack_path_ = os.path.join( outline.explainer_path, attack.__class__.__name__, @@ -60,7 +58,6 @@ if __name__ == "__main__": if attack_data is None: continue exp = outline.get_explanation(item=attack_data, path=data_attack_path_) - pbar.update(1) if exp is None: continue else: @@ -103,5 +100,3 @@ if __name__ == "__main__": ) with open(os.path.join(outline.out_dir, "done"), "w") as f: f.write("") - - pbar.close()