Fixing many bugs
This commit is contained in:
		
							parent
							
								
									b7fb522255
								
							
						
					
					
						commit
						3d0d3ec451
					
				
					 3 changed files with 102 additions and 80 deletions
				
			
		|  | @ -19,57 +19,57 @@ if "__main__" == __name__: | ||||||
| 
 | 
 | ||||||
|     DATASET = [ |     DATASET = [ | ||||||
|         "CIFAR10", |         "CIFAR10", | ||||||
|         "TRIANGLES", |         # "TRIANGLES", | ||||||
|         "COLORS-3", |         # "COLORS-3", | ||||||
|         "REDDIT-BINARY", |         # "REDDIT-BINARY", | ||||||
|         "REDDIT-MULTI-5K", |         # "REDDIT-MULTI-5K", | ||||||
|         "REDDIT-MULTI-12K", |         # "REDDIT-MULTI-12K", | ||||||
|         "COLLAB", |         # "COLLAB", | ||||||
|         "DBLP_v1", |         # "DBLP_v1", | ||||||
|         "COIL-DEL", |         # "COIL-DEL", | ||||||
|         "COIL-RAG", |         # "COIL-RAG", | ||||||
|         "Fingerprint", |         # "Fingerprint", | ||||||
|         "Letter-high", |         # "Letter-high", | ||||||
|         "Letter-low", |         # "Letter-low", | ||||||
|         "Letter-med", |         # "Letter-med", | ||||||
|         "MSRC_9", |         "MSRC_9", | ||||||
|         "MSRC_21", |         # "MSRC_21", | ||||||
|         "MSRC_21C", |         "MSRC_21C", | ||||||
|         "DD", |         # "DD", | ||||||
|         "ENZYMES", |         # "ENZYMES", | ||||||
|         "PROTEINS", |         "PROTEINS", | ||||||
|         "QM9", |         # "QM9", | ||||||
|         "MUTAG", |         # "MUTAG", | ||||||
|         "Mutagenicity", |         # "Mutagenicity", | ||||||
|         "AIDS", |         # "AIDS", | ||||||
|         "PATTERN", |         # "PATTERN", | ||||||
|         "CLUSTER", |         # "CLUSTER", | ||||||
|         "MNIST", |         "MNIST", | ||||||
|         "CIFAR10", |         "CIFAR10", | ||||||
|         "TSP", |         # "TSP", | ||||||
|         "CSL", |         # "CSL", | ||||||
|         "KarateClub", |         # "KarateClub", | ||||||
|         "CS", |         # "CS", | ||||||
|         "Physics", |         # "Physics", | ||||||
|         "BBBP", |         # "BBBP", | ||||||
|         "Tox21", |         # "Tox21", | ||||||
|         "HIV", |         # "HIV", | ||||||
|         "PCBA", |         # "PCBA", | ||||||
|         "MUV", |         # "MUV", | ||||||
|         "BACE", |         # "BACE", | ||||||
|         "SIDER", |         # "SIDER", | ||||||
|         "ClinTox", |         # "ClinTox", | ||||||
|         "AIFB", |         # "AIFB", | ||||||
|         "AM", |         # "AM", | ||||||
|         "MUTAG", |         # "MUTAG", | ||||||
|         "BGS", |         # "BGS", | ||||||
|         "FAUST", |         # "FAUST", | ||||||
|         "DynamicFAUST", |         # "DynamicFAUST", | ||||||
|         "ShapeNet", |         # "ShapeNet", | ||||||
|         "ModelNet10", |         # "ModelNet10", | ||||||
|         "ModelNet40", |         # "ModelNet40", | ||||||
|         "PascalVOC-SP", |         # "PascalVOC-SP", | ||||||
|         "COCO-SP", |         # "COCO-SP", | ||||||
|     ] |     ] | ||||||
|     EXPLAINER = [ |     EXPLAINER = [ | ||||||
|         "CAM", |         "CAM", | ||||||
|  | @ -81,8 +81,8 @@ if "__main__" == __name__: | ||||||
|         # "PGExplainer", |         # "PGExplainer", | ||||||
|         "PGMExplainer", |         "PGMExplainer", | ||||||
|         "RandomExplainer", |         "RandomExplainer", | ||||||
|         "SubgraphX", |         # "SubgraphX", | ||||||
|         "GraphMASK", |         # "GraphMASK", | ||||||
|         "GNNExplainer", |         "GNNExplainer", | ||||||
|         "EIXGNN", |         "EIXGNN", | ||||||
|         "SCGNN", |         "SCGNN", | ||||||
|  | @ -113,7 +113,9 @@ if "__main__" == __name__: | ||||||
|                 explaining_cfg["model"] = {} |                 explaining_cfg["model"] = {} | ||||||
|                 explaining_cfg["model"]["ckpt"] = string_to_python(model_kind) |                 explaining_cfg["model"]["ckpt"] = string_to_python(model_kind) | ||||||
|                 explaining_cfg["model"]["path"] = string_to_python( |                 explaining_cfg["model"]["path"] = string_to_python( | ||||||
|                     "/media/data/SIC/araison/exps/pyg_fork/graphgym/results/graph_classif_base_grid_graph_classif_grid" |                     # "/media/data/SIC/araison/exps/pyg_fork/graphgym/results/graph_classif_base_grid_graph_classif_grid" | ||||||
|  |                     "/home/SIC/araison/exps/pytorch_geometric/graphgym/results/" | ||||||
|  |                     # "/media/data/SIC/araison/exps/pyg_fork/graphgym/results/graph_classif_base_grid_graph_classif_grid" | ||||||
|                 ) |                 ) | ||||||
|                 # explaining_cfg['out_dir']='./explanation' |                 # explaining_cfg['out_dir']='./explanation' | ||||||
|                 # explaining_cfg['print']='both' |                 # explaining_cfg['print']='both' | ||||||
|  |  | ||||||
|  | @ -183,7 +183,6 @@ class ExplainingOutline(object): | ||||||
|             return None |             return None | ||||||
| 
 | 
 | ||||||
|     def load_indexes(self): |     def load_indexes(self): | ||||||
| 
 |  | ||||||
|         item = self.explaining_cfg.dataset.item |         item = self.explaining_cfg.dataset.item | ||||||
|         if isinstance(item, (list, int)): |         if isinstance(item, (list, int)): | ||||||
|             indexes = item |             indexes = item | ||||||
|  | @ -566,17 +565,22 @@ class ExplainingOutline(object): | ||||||
|     def get_explanation(self, item: Data, path: str): |     def get_explanation(self, item: Data, path: str): | ||||||
|         if is_exists(path): |         if is_exists(path): | ||||||
|             if self.explaining_cfg.explainer.force: |             if self.explaining_cfg.explainer.force: | ||||||
|                 explanation = _get_explanation(self.explainer, item) |                 try: | ||||||
|                 if explanation is None: |                     explanation = _get_explanation(self.explainer, item) | ||||||
|                     logging.warning( |                     if explanation is None: | ||||||
|                         " EXP::Generated; Path %s; FAILED", |                         logging.error( | ||||||
|                         (path), |                             " EXP::Generated; Path %s; FAILED", | ||||||
|                     ) |                             (path), | ||||||
|                 else: |                         ) | ||||||
|                     logging.debug( | 
 | ||||||
|                         "EXP::Generated; Path %s; SUCCEEDED", |                     else: | ||||||
|                         (path), |                         logging.debug( | ||||||
|                     ) |                             "EXP::Generated; Path %s; SUCCEEDED", | ||||||
|  |                             (path), | ||||||
|  |                         ) | ||||||
|  |                 except Exception as e: | ||||||
|  |                     logging.error(str(e)) | ||||||
|  |                     return None | ||||||
|             else: |             else: | ||||||
|                 explanation = _load_explanation(path) |                 explanation = _load_explanation(path) | ||||||
|                 logging.debug( |                 logging.debug( | ||||||
|  | @ -585,13 +589,17 @@ class ExplainingOutline(object): | ||||||
|                 ) |                 ) | ||||||
|                 explanation = explanation.to(self.cfg.accelerator) |                 explanation = explanation.to(self.cfg.accelerator) | ||||||
|         else: |         else: | ||||||
|             explanation = _get_explanation(self.explainer, item) |             try: | ||||||
|             get_pred(self.explainer, explanation) |                 explanation = _get_explanation(self.explainer, item) | ||||||
|             _save_explanation(explanation, path) |                 get_pred(self.explainer, explanation) | ||||||
|             logging.debug( |                 _save_explanation(explanation, path) | ||||||
|                 "EXP::Generated; Path %s; SUCCEEDED", |                 logging.debug( | ||||||
|                 (path), |                     "EXP::Generated; Path %s; SUCCEEDED", | ||||||
|             ) |                     (path), | ||||||
|  |                 ) | ||||||
|  |             except Exception as e: | ||||||
|  |                 logging.error(str(e)) | ||||||
|  |                 return None | ||||||
| 
 | 
 | ||||||
|         return explanation |         return explanation | ||||||
| 
 | 
 | ||||||
|  | @ -694,12 +702,15 @@ class ExplainingOutline(object): | ||||||
|     def get_attack(self, attack: Attack, item: Data, path: str): |     def get_attack(self, attack: Attack, item: Data, path: str): | ||||||
|         if is_exists(path): |         if is_exists(path): | ||||||
|             if self.explaining_cfg.explainer.force: |             if self.explaining_cfg.explainer.force: | ||||||
|                 data_attack = attack.get_attacked_prediction(item) |                 try: | ||||||
|                 logging.debug( |                     data_attack = attack.get_attacked_prediction(item) | ||||||
|                     "ATTACK::Generated %s; Path %s; SUCCEEDED", |                     logging.debug( | ||||||
|                     (path), |                         "ATTACK::Generated %s; Path %s; SUCCEEDED", | ||||||
|                 ) |                         (path), | ||||||
| 
 |                     ) | ||||||
|  |                 except Exception as e: | ||||||
|  |                     logging.error(str(e)) | ||||||
|  |                     return None | ||||||
|             else: |             else: | ||||||
|                 data_attack = _load_explanation(path) |                 data_attack = _load_explanation(path) | ||||||
|                 logging.debug( |                 logging.debug( | ||||||
|  | @ -707,12 +718,17 @@ class ExplainingOutline(object): | ||||||
|                     (path), |                     (path), | ||||||
|                 ) |                 ) | ||||||
|         else: |         else: | ||||||
|             data_attack = attack.get_attacked_prediction(item) |             try: | ||||||
|             _save_explanation(data_attack, path) |                 data_attack = attack.get_attacked_prediction(item) | ||||||
|             logging.debug( |                 _save_explanation(data_attack, path) | ||||||
|                 "ATTACK::Generated %s; Path %s; SUCCEEDED", |                 logging.debug( | ||||||
|                 (path), |                     "ATTACK::Generated %s; Path %s; SUCCEEDED", | ||||||
|             ) |                     (path), | ||||||
|  |                 ) | ||||||
|  |             except Exception as e: | ||||||
|  |                 logging.error(str(e)) | ||||||
|  |                 return None | ||||||
|  | 
 | ||||||
|         return data_attack |         return data_attack | ||||||
| 
 | 
 | ||||||
|     def setup_experiment(self): |     def setup_experiment(self): | ||||||
|  |  | ||||||
							
								
								
									
										4
									
								
								main.py
									
										
									
									
									
								
							
							
						
						
									
										4
									
								
								main.py
									
										
									
									
									
								
							|  | @ -40,6 +40,8 @@ if __name__ == "__main__": | ||||||
|             data_attack = outline.get_attack( |             data_attack = outline.get_attack( | ||||||
|                 attack=attack, item=item, path=data_attack_path |                 attack=attack, item=item, path=data_attack_path | ||||||
|             ) |             ) | ||||||
|  |             if data_attack is None: | ||||||
|  |                 continue | ||||||
| 
 | 
 | ||||||
|         item, index = outline.get_item() |         item, index = outline.get_item() | ||||||
| 
 | 
 | ||||||
|  | @ -57,6 +59,8 @@ if __name__ == "__main__": | ||||||
|             attack_data = outline.get_attack( |             attack_data = outline.get_attack( | ||||||
|                 attack=attack, item=item, path=data_attack_path_ |                 attack=attack, item=item, path=data_attack_path_ | ||||||
|             ) |             ) | ||||||
|  |             if attack_data is None: | ||||||
|  |                 continue | ||||||
|             exp = outline.get_explanation(item=attack_data, path=data_attack_path_) |             exp = outline.get_explanation(item=attack_data, path=data_attack_path_) | ||||||
|             pbar.update(1) |             pbar.update(1) | ||||||
|             if exp is None: |             if exp is None: | ||||||
|  |  | ||||||
		Loading…
	
	Add table
		
		Reference in a new issue
	
	 araison
						araison