Adding new features and first draft of main.py
This commit is contained in:
		
							parent
							
								
									a00e73d4f0
								
							
						
					
					
						commit
						074ff25c83
					
				
					 5 changed files with 203 additions and 17 deletions
				
			
		|  | @ -49,11 +49,4 @@ class Metric(ABC): | ||||||
| 
 | 
 | ||||||
|         return out |         return out | ||||||
| 
 | 
 | ||||||
|     def save_config(self, path) -> None: | 
 | ||||||
|         config = {k: getattr(self, k) for k in dir(self)} |  | ||||||
|         config = { |  | ||||||
|             k: v |  | ||||||
|             for k, v in config.items() |  | ||||||
|             if isinstance(v, (int, float, str, bool)) or v is None |  | ||||||
|         } |  | ||||||
|         write_json(config, path) |  | ||||||
|  |  | ||||||
|  | @ -7,11 +7,12 @@ import logging | ||||||
| import os | import os | ||||||
| 
 | 
 | ||||||
| import torch | import torch | ||||||
| from explaining_framework.utils.io import read_yaml |  | ||||||
| from torch_geometric.graphgym.model_builder import create_model | from torch_geometric.graphgym.model_builder import create_model | ||||||
| from torch_geometric.graphgym.train import GraphGymDataModule | from torch_geometric.graphgym.train import GraphGymDataModule | ||||||
| from torch_geometric.graphgym.utils.io import json_to_dict_list | from torch_geometric.graphgym.utils.io import json_to_dict_list | ||||||
| 
 | 
 | ||||||
|  | from explaining_framework.utils.io import read_yaml | ||||||
|  | 
 | ||||||
| MODEL_STATE = "model_state" | MODEL_STATE = "model_state" | ||||||
| OPTIMIZER_STATE = "optimizer_state" | OPTIMIZER_STATE = "optimizer_state" | ||||||
| SCHEDULER_STATE = "scheduler_state" | SCHEDULER_STATE = "scheduler_state" | ||||||
|  | @ -44,14 +45,15 @@ class LoadModelInfo(object): | ||||||
| 
 | 
 | ||||||
|     def list_stats(self, path) -> list: |     def list_stats(self, path) -> list: | ||||||
|         info = [] |         info = [] | ||||||
|         for path in glob.glob( |         for path_ in glob.glob( | ||||||
|             os.path.join(path, "[0-9]", self.wrt_metric, "stats.json") |             os.path.join(path, "[0-9]", self.wrt_metric, "stats.json") | ||||||
|         ): |         ): | ||||||
|             stats = json_to_dict_list(path) |             stats = json_to_dict_list(path_) | ||||||
|             for stat in stats: |             for stat in stats: | ||||||
|                 xp_dir_path = os.path.dirname(os.path.dirname(os.path.dirname(path))) |                 xp_dir_path = os.path.dirname(os.path.dirname(os.path.dirname(path_))) | ||||||
|  |                 seed = int(os.path.basename(os.path.dirname(os.path.dirname(path_)))) | ||||||
|                 ckpt_dir_path = os.path.join( |                 ckpt_dir_path = os.path.join( | ||||||
|                     os.path.dirname(os.path.dirname(path)), "ckpt" |                     os.path.dirname(os.path.dirname(path_)), "ckpt" | ||||||
|                 ) |                 ) | ||||||
|                 cfg_path = os.path.join(xp_dir_path, "config.yaml") |                 cfg_path = os.path.join(xp_dir_path, "config.yaml") | ||||||
|                 epoch = stat["epoch"] |                 epoch = stat["epoch"] | ||||||
|  | @ -68,12 +70,16 @@ class LoadModelInfo(object): | ||||||
|                             epoch=epoch, ckpt_dir_path=ckpt_dir_path |                             epoch=epoch, ckpt_dir_path=ckpt_dir_path | ||||||
|                         ), |                         ), | ||||||
|                         "cfg_path": cfg_path, |                         "cfg_path": cfg_path, | ||||||
|  |                         "seed": seed, | ||||||
|                         "epoch": epoch, |                         "epoch": epoch, | ||||||
|                         "accuracy": accuracy, |                         "accuracy": accuracy, | ||||||
|                         "loss": loss, |                         "loss": loss, | ||||||
|                         "lr": lr, |                         "lr": lr, | ||||||
|                         "params": params, |                         "params": params, | ||||||
|                         "time_iter": time_iter, |                         "time_iter": time_iter, | ||||||
|  |                         "which": self.which | ||||||
|  |                         if self.which in ["best", "worst"] | ||||||
|  |                         else None, | ||||||
|                     } |                     } | ||||||
|                 ) |                 ) | ||||||
|         return info |         return info | ||||||
|  | @ -112,6 +118,22 @@ class LoadModelInfo(object): | ||||||
|             self.info = [item for item in stats if item["ckpt_path"] == self.which][0] |             self.info = [item for item in stats if item["ckpt_path"] == self.which][0] | ||||||
|         return self.info |         return self.info | ||||||
| 
 | 
 | ||||||
|  |     def get_model_signature(self): | ||||||
|  |         if self.info is None: | ||||||
|  |             self.set_info() | ||||||
|  | 
 | ||||||
|  |         model_name = os.path.basename(self.info["xp_dir_name"]) | ||||||
|  |         model_seed = self.info["seed"] | ||||||
|  |         epoch = os.path.basename(self.info["ckpt_path"]) | ||||||
|  |         model_signature = "-".join( | ||||||
|  |             [ | ||||||
|  |                 f"{name}={val}" | ||||||
|  |                 for name, val in zip(["name", "seed"], [model_name, model_seed]) | ||||||
|  |             ] | ||||||
|  |             + [epoch] | ||||||
|  |         ) | ||||||
|  |         return model_signature | ||||||
|  | 
 | ||||||
|     def get_ckpt_path(self, epoch: int, ckpt_dir_path: str): |     def get_ckpt_path(self, epoch: int, ckpt_dir_path: str): | ||||||
|         paths = os.path.join(ckpt_dir_path, "*.ckpt") |         paths = os.path.join(ckpt_dir_path, "*.ckpt") | ||||||
|         ckpts = [] |         ckpts = [] | ||||||
|  |  | ||||||
|  | @ -4,6 +4,7 @@ from typing import Any | ||||||
| from eixgnn.eixgnn import EiXGNN | from eixgnn.eixgnn import EiXGNN | ||||||
| from scgnn.scgnn import SCGNN | from scgnn.scgnn import SCGNN | ||||||
| from torch_geometric.data import Batch, Data | from torch_geometric.data import Batch, Data | ||||||
|  | from torch_geometric.data.loader.dataloader import DataLoader | ||||||
| from torch_geometric.explain import Explainer | from torch_geometric.explain import Explainer | ||||||
| from torch_geometric.graphgym.config import cfg | from torch_geometric.graphgym.config import cfg | ||||||
| from torch_geometric.graphgym.loader import create_dataset | from torch_geometric.graphgym.loader import create_dataset | ||||||
|  | @ -62,6 +63,8 @@ all_fidelity = [ | ||||||
|     "fidelity_plus_prob", |     "fidelity_plus_prob", | ||||||
|     "fidelity_minus_prob", |     "fidelity_minus_prob", | ||||||
|     "infidelity_KL", |     "infidelity_KL", | ||||||
|  |     "characterization", | ||||||
|  |     "characterization_prob", | ||||||
| ] | ] | ||||||
| all_accuracy = [ | all_accuracy = [ | ||||||
|     "precision_score", |     "precision_score", | ||||||
|  | @ -94,6 +97,7 @@ class ExplainingOutline(object): | ||||||
|         self.model_info = None |         self.model_info = None | ||||||
|         self.metrics = None |         self.metrics = None | ||||||
|         self.attacks = None |         self.attacks = None | ||||||
|  |         self.model_signature = None | ||||||
| 
 | 
 | ||||||
|         self.load_explaining_cfg() |         self.load_explaining_cfg() | ||||||
|         self.load_model_info() |         self.load_model_info() | ||||||
|  | @ -112,6 +116,7 @@ class ExplainingOutline(object): | ||||||
|             which=self.explaining_cfg.model.ckpt, |             which=self.explaining_cfg.model.ckpt, | ||||||
|         ) |         ) | ||||||
|         self.model_info = info.set_info() |         self.model_info = info.set_info() | ||||||
|  |         self.model_signature = info.get_model_signature() | ||||||
| 
 | 
 | ||||||
|     def load_cfg(self): |     def load_cfg(self): | ||||||
|         cfg.set_new_allowed(True) |         cfg.set_new_allowed(True) | ||||||
|  | @ -166,6 +171,7 @@ class ExplainingOutline(object): | ||||||
|         if isinstance(self.explaining_cfg.dataset.specific_items, int): |         if isinstance(self.explaining_cfg.dataset.specific_items, int): | ||||||
|             ind = self.explaining_cfg.dataset.specific_items |             ind = self.explaining_cfg.dataset.specific_items | ||||||
|             self.dataset = self.dataset[ind : ind + 1] |             self.dataset = self.dataset[ind : ind + 1] | ||||||
|  |         self.dataset = DataLoader(dataset=dataset, shuffle=False, batch_size=1) | ||||||
| 
 | 
 | ||||||
|     def load_explainer(self): |     def load_explainer(self): | ||||||
|         self.load_explainer_cfg() |         self.load_explainer_cfg() | ||||||
|  | @ -199,6 +205,10 @@ class ExplainingOutline(object): | ||||||
|                     interest_map_norm=self.explainer_cfg.interest_map_norm, |                     interest_map_norm=self.explainer_cfg.interest_map_norm, | ||||||
|                     score_map_norm=self.explainer_cfg.score_map_norm, |                     score_map_norm=self.explainer_cfg.score_map_norm, | ||||||
|                 ) |                 ) | ||||||
|  |         elif name is None: | ||||||
|  |             explaining_algorithm = None | ||||||
|  |         else: | ||||||
|  |             raise ValueError(f"{name_} Metric is not supported yet") | ||||||
|         self.explaining_algorithm = explaining_algorithm |         self.explaining_algorithm = explaining_algorithm | ||||||
| 
 | 
 | ||||||
|     def load_metric(self): |     def load_metric(self): | ||||||
|  | @ -228,6 +238,10 @@ class ExplainingOutline(object): | ||||||
|                 raise ValueError( |                 raise ValueError( | ||||||
|                     f"The metric {name} is not supported for dataset {self.explaining_cfg.dataset.name} yet, it requires groundtruth explanation" |                     f"The metric {name} is not supported for dataset {self.explaining_cfg.dataset.name} yet, it requires groundtruth explanation" | ||||||
|                 ) |                 ) | ||||||
|  |         elif name_ is None: | ||||||
|  |             self.metrics = [] | ||||||
|  |         else: | ||||||
|  |             raise ValueError(f"{name_} Metric is not supported yet") | ||||||
| 
 | 
 | ||||||
|     def load_attack(self): |     def load_attack(self): | ||||||
|         if self.cfg is None: |         if self.cfg is None: | ||||||
|  | @ -238,5 +252,9 @@ class ExplainingOutline(object): | ||||||
|         if name_ == "all": |         if name_ == "all": | ||||||
|             all_rob_metrics = [Attack(name) for name in all_robust] |             all_rob_metrics = [Attack(name) for name in all_robust] | ||||||
|             self.attacks = all_rob_metrics |             self.attacks = all_rob_metrics | ||||||
|         if name_ in all_robust: |         elif name_ in all_robust: | ||||||
|             self.attacks = [Attack(name_)] |             self.attacks = [Attack(name_)] | ||||||
|  |         elif name_ is None: | ||||||
|  |             slef.attacks = [] | ||||||
|  |         else: | ||||||
|  |             raise ValueError(f"{name_} is an Attack method that is not supported yet") | ||||||
|  |  | ||||||
|  | @ -1,5 +1,6 @@ | ||||||
| import json | import json | ||||||
| import os | import os | ||||||
|  | 
 | ||||||
| import yaml | import yaml | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | @ -24,6 +25,26 @@ def write_yaml(data: dict, path: str) -> None: | ||||||
|     with open(path, "w") as f: |     with open(path, "w") as f: | ||||||
|         data = yaml.dump(data, f) |         data = yaml.dump(data, f) | ||||||
| 
 | 
 | ||||||
| def is_exists(path:str)-> bool: |  | ||||||
|     return os.path.exists(path)  |  | ||||||
| 
 | 
 | ||||||
|  | def is_exists(path: str) -> bool: | ||||||
|  |     return os.path.exists(path) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def get_obj_config(obj): | ||||||
|  |     config = {k: getattr(obj, k) for k in dir(obj)} | ||||||
|  |     config = { | ||||||
|  |         k: v | ||||||
|  |         for k, v in config.items() | ||||||
|  |         if isinstance(v, (int, float, str, bool)) or v is None | ||||||
|  |     } | ||||||
|  |     return config | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def save_obj_config(obj, path) -> None: | ||||||
|  |     config = get_obj_config(obj) | ||||||
|  |     write_json(config, path) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def obj_config_to_str(obj) -> str: | ||||||
|  |     config = get_obj_config(obj) | ||||||
|  |     return "-".join([f"{k}={v}" for k, v in config.items()]) | ||||||
|  |  | ||||||
							
								
								
									
										134
									
								
								main.py
									
										
									
									
									
								
							
							
						
						
									
										134
									
								
								main.py
									
										
									
									
									
								
							|  | @ -3,6 +3,138 @@ | ||||||
| # | # | ||||||
| 
 | 
 | ||||||
| import os | import os | ||||||
|  | import time | ||||||
|  | 
 | ||||||
|  | from torch_geometric import seed_everything | ||||||
|  | from torch_geometric.data.makedirs import makedirs | ||||||
|  | from torch_geometric.explain import Explainer | ||||||
|  | from torch_geometric.explain.config import ThresholdConfig | ||||||
|  | from torch_geometric.graphgym.config import cfg | ||||||
|  | from torch_geometric.graphgym.utils.device import auto_select_device | ||||||
|  | 
 | ||||||
| from explaining_framework.config.explaining_config import explaining_cfg | from explaining_framework.config.explaining_config import explaining_cfg | ||||||
| from explaining_framework.utils.explaining.cmd_args import parse_args | from explaining_framework.utils.explaining.cmd_args import parse_args | ||||||
| from explaining_framework.utils.explaining.outline import parse_args | from explaining_framework.utils.explaining.outline import ExplainingOutline | ||||||
|  | from explaining_framework.utils.io import (obj_config_to_str, read_json, | ||||||
|  |                                            write_json, write_yaml) | ||||||
|  | from explaining_framework.utils.explanation.adjust import Adjust | ||||||
|  | 
 | ||||||
|  | # inference, time, force, | ||||||
|  | 
 | ||||||
|  | def get_pred(explanation,force=False): | ||||||
|  |     dict_ = explanation.to_dict() | ||||||
|  |     if dict_.get('pred') is None or dict_.get('pred_masked') or force: | ||||||
|  |         pred = explainer.get_prediction(explanation) | ||||||
|  |         pred_masked = explainer.get_masked_prediction(x=explanation.x,edge_index=explanation.edge_index,node_mask=explanation.node_mask,edge_mask=explanation.edge_mask) | ||||||
|  |         explanation.__setattr__('pred',pred) | ||||||
|  |         explanation.__setattr__('pred_masked',pred_masked) | ||||||
|  |         return explanation | ||||||
|  |     else: | ||||||
|  |         return explanation | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | if __name__ == "__main__": | ||||||
|  |     args = parse_args() | ||||||
|  |     outline = ExplainingOutline(args.explaining_cfg_file) | ||||||
|  |     auto_select_device() | ||||||
|  | 
 | ||||||
|  |     # Load components | ||||||
|  |     dataset = outline.dataset.to(cfg.accelerator) | ||||||
|  |     model = outline.model.to(cfg.accelerator) | ||||||
|  |     model_info = outline.model_info | ||||||
|  |     metrics = outline.metrics | ||||||
|  |     explaining_algorithm = outline.explaining_algorithm | ||||||
|  |     attacks = outline.attacks | ||||||
|  |     explainer_cfg = outline.explainer_cfg | ||||||
|  |     model_signature = outline.model_signature | ||||||
|  | 
 | ||||||
|  |     # Set seed | ||||||
|  |     seed_everything(explaining_cfg.seed) | ||||||
|  | 
 | ||||||
|  |     # Global path | ||||||
|  |     global_path = os.path.join(explaining_cfg.out_dir, model_signature) | ||||||
|  |     makedirs(global_path) | ||||||
|  |     write_yaml(cfg, os.path.join(global_path, "config.yaml")) | ||||||
|  |         write_json(model_info, os.path.join(global_path, "info.json")) | ||||||
|  | 
 | ||||||
|  |     global_path = os.path.join(global_path, explaining_cfg.explainer.name+'_'+obj_config_to_str(explaining_algorithm)) | ||||||
|  |     makedirs(global_path) | ||||||
|  |     write_yaml(explaining_cfg, os.path.join(global_path, explaining_cfg.cfg_dest)) | ||||||
|  |     write_yaml(explainer_cfg, os.path.join(global_path, "explainer_cfg.yaml")) | ||||||
|  | 
 | ||||||
|  |     global_path = os.path.join(global_path, obj_config_to_str(explaining_algorithm)) | ||||||
|  |     makedirs(global_path) | ||||||
|  |     explainer = Explainer( | ||||||
|  |         model=model, | ||||||
|  |         algorithm=explaining_algorithm, | ||||||
|  |         explainer_config=dict( | ||||||
|  |             explanation_type=explaining_cfg.explanation_type, | ||||||
|  |             node_mask_type="object", | ||||||
|  |             edge_mask_type="object", | ||||||
|  |         ), | ||||||
|  |         model_config=dict( | ||||||
|  |             mode="regression", | ||||||
|  |             task_level=cfg.dataset.task, | ||||||
|  |             return_type=explaining_cfg.model_config.return_type, | ||||||
|  |         ), | ||||||
|  |     ) | ||||||
|  |     # Save explaining configuration | ||||||
|  |     for index, item in enumerate(dataset): | ||||||
|  |         save_raw_path = os.path.join(global_path, "raw") | ||||||
|  |         makedirs(save_raw_path) | ||||||
|  |         explanation_path = os.path.join(save_raw_path, f"{index}.json") | ||||||
|  | 
 | ||||||
|  |         if is_exists(explanation_path): | ||||||
|  |             if explaining_cfg.explainer.force: | ||||||
|  |                 explanation = explainer( | ||||||
|  |                     x=item.x, | ||||||
|  |                     edge_index=item.edge_index, | ||||||
|  |                     index=item.y, | ||||||
|  |                     target=item.y, | ||||||
|  |                 ) | ||||||
|  |             else: | ||||||
|  |                 explanation = load_explanation(explanation_path) | ||||||
|  |         else: | ||||||
|  |             explanation = explainer( | ||||||
|  |                 x=item.x, | ||||||
|  |                 edge_index=item.edge_index, | ||||||
|  |                 index=item.y, | ||||||
|  |                 target=item.y, | ||||||
|  |             ) | ||||||
|  |         explanation = get_pred(explanation,force=False) | ||||||
|  |         save_explanation(explanation,explanation_path) | ||||||
|  |         for apply_relu in [True,False]: | ||||||
|  |             for apply_absolute in [True,False]: | ||||||
|  |                 adjust = Adjust(apply_relu=apply_relu,apply_absolute=apply_absolute) | ||||||
|  |                 save_raw_path = os.path.join(global_path,f'adjust-{obj_config_to_str(adjust)}') | ||||||
|  |                 makedirs(save_raw_path) | ||||||
|  |                 explanation  = adjust.forward(explanation) | ||||||
|  |                 explanation_path = os.path.join(save_raw_path, f"{index}.json") | ||||||
|  |                 explanation = get_pred(explanation,force=True) | ||||||
|  |                 save_explanation(explanation,explanation_path) | ||||||
|  | 
 | ||||||
|  |                 for threshold_approach in ['hard','topk','topk_hard']: | ||||||
|  |                     for threshold_value in explaining_cfg.threshold_config.value:  | ||||||
|  | 
 | ||||||
|  |                         masking_path =os.path.join(save_raw_path,f'threshold={threshold_approach}-value={value}') | ||||||
|  |                         exp_threshold_path = os.path.join(masking_path,f'{index}.json') | ||||||
|  | 
 | ||||||
|  |                         if is_exists(exp_threshold_path): | ||||||
|  |                             explanation = load_explanation(exp_threshold_path) | ||||||
|  |                         else: | ||||||
|  |                             threshold_conf = {'threshold_type':threshold_approach,'value':threshold_value} | ||||||
|  |                             explainer.threshold_config = ThresholdConfig.cast(threshold_conf)  | ||||||
|  | 
 | ||||||
|  |                             expl = copy.copy(explanation) | ||||||
|  |                             exp_threshold = explainer._post_process(expl) | ||||||
|  |                             exp_threshold= get_pred(exp_threshold,force=True) | ||||||
|  | 
 | ||||||
|  |                             save_explanation(exp_threshold,exp_threshold_path) | ||||||
|  |                         for metric in metrics: | ||||||
|  |                             metric_path =os.path.join(masking_path,f'{obj_config_to_str(metric)}') | ||||||
|  |                             if is_exists(os.path.join(metric_path,f'{index}.json')): | ||||||
|  |                                 continue | ||||||
|  |                             else: | ||||||
|  |                                 out = metric.forward(exp_threshold) | ||||||
|  |                                 write_json({f'{metric.name}':out}) | ||||||
|  | 
 | ||||||
|  |  | ||||||
		Loading…
	
	Add table
		
		Reference in a new issue
	
	 araison
						araison