From 3d0d3ec451bdbe6fe09d356d6779af037b03ab86 Mon Sep 17 00:00:00 2001 From: araison Date: Tue, 10 Jan 2023 18:49:38 +0100 Subject: [PATCH] Fixing many bugs --- explaining_framework/utils/config_gen.py | 100 +++++++++--------- .../utils/explaining/outline.py | 78 ++++++++------ main.py | 4 + 3 files changed, 102 insertions(+), 80 deletions(-) diff --git a/explaining_framework/utils/config_gen.py b/explaining_framework/utils/config_gen.py index 65189b1..fc5367d 100644 --- a/explaining_framework/utils/config_gen.py +++ b/explaining_framework/utils/config_gen.py @@ -19,57 +19,57 @@ if "__main__" == __name__: DATASET = [ "CIFAR10", - "TRIANGLES", - "COLORS-3", - "REDDIT-BINARY", - "REDDIT-MULTI-5K", - "REDDIT-MULTI-12K", - "COLLAB", - "DBLP_v1", - "COIL-DEL", - "COIL-RAG", - "Fingerprint", - "Letter-high", - "Letter-low", - "Letter-med", + # "TRIANGLES", + # "COLORS-3", + # "REDDIT-BINARY", + # "REDDIT-MULTI-5K", + # "REDDIT-MULTI-12K", + # "COLLAB", + # "DBLP_v1", + # "COIL-DEL", + # "COIL-RAG", + # "Fingerprint", + # "Letter-high", + # "Letter-low", + # "Letter-med", "MSRC_9", - "MSRC_21", + # "MSRC_21", "MSRC_21C", - "DD", - "ENZYMES", + # "DD", + # "ENZYMES", "PROTEINS", - "QM9", - "MUTAG", - "Mutagenicity", - "AIDS", - "PATTERN", - "CLUSTER", + # "QM9", + # "MUTAG", + # "Mutagenicity", + # "AIDS", + # "PATTERN", + # "CLUSTER", "MNIST", "CIFAR10", - "TSP", - "CSL", - "KarateClub", - "CS", - "Physics", - "BBBP", - "Tox21", - "HIV", - "PCBA", - "MUV", - "BACE", - "SIDER", - "ClinTox", - "AIFB", - "AM", - "MUTAG", - "BGS", - "FAUST", - "DynamicFAUST", - "ShapeNet", - "ModelNet10", - "ModelNet40", - "PascalVOC-SP", - "COCO-SP", + # "TSP", + # "CSL", + # "KarateClub", + # "CS", + # "Physics", + # "BBBP", + # "Tox21", + # "HIV", + # "PCBA", + # "MUV", + # "BACE", + # "SIDER", + # "ClinTox", + # "AIFB", + # "AM", + # "MUTAG", + # "BGS", + # "FAUST", + # "DynamicFAUST", + # "ShapeNet", + # "ModelNet10", + # "ModelNet40", + # "PascalVOC-SP", + # "COCO-SP", ] EXPLAINER = [ "CAM", @@ -81,8 +81,8 @@ if "__main__" == __name__: # "PGExplainer", "PGMExplainer", "RandomExplainer", - "SubgraphX", - "GraphMASK", + # "SubgraphX", + # "GraphMASK", "GNNExplainer", "EIXGNN", "SCGNN", @@ -113,7 +113,9 @@ if "__main__" == __name__: explaining_cfg["model"] = {} explaining_cfg["model"]["ckpt"] = string_to_python(model_kind) explaining_cfg["model"]["path"] = string_to_python( - "/media/data/SIC/araison/exps/pyg_fork/graphgym/results/graph_classif_base_grid_graph_classif_grid" + # "/media/data/SIC/araison/exps/pyg_fork/graphgym/results/graph_classif_base_grid_graph_classif_grid" + "/home/SIC/araison/exps/pytorch_geometric/graphgym/results/" + # "/media/data/SIC/araison/exps/pyg_fork/graphgym/results/graph_classif_base_grid_graph_classif_grid" ) # explaining_cfg['out_dir']='./explanation' # explaining_cfg['print']='both' diff --git a/explaining_framework/utils/explaining/outline.py b/explaining_framework/utils/explaining/outline.py index 10413fc..d1ba6a3 100644 --- a/explaining_framework/utils/explaining/outline.py +++ b/explaining_framework/utils/explaining/outline.py @@ -183,7 +183,6 @@ class ExplainingOutline(object): return None def load_indexes(self): - item = self.explaining_cfg.dataset.item if isinstance(item, (list, int)): indexes = item @@ -566,17 +565,22 @@ class ExplainingOutline(object): def get_explanation(self, item: Data, path: str): if is_exists(path): if self.explaining_cfg.explainer.force: - explanation = _get_explanation(self.explainer, item) - if explanation is None: - logging.warning( - " EXP::Generated; Path %s; FAILED", - (path), - ) - else: - logging.debug( - "EXP::Generated; Path %s; SUCCEEDED", - (path), - ) + try: + explanation = _get_explanation(self.explainer, item) + if explanation is None: + logging.error( + " EXP::Generated; Path %s; FAILED", + (path), + ) + + else: + logging.debug( + "EXP::Generated; Path %s; SUCCEEDED", + (path), + ) + except Exception as e: + logging.error(str(e)) + return None else: explanation = _load_explanation(path) logging.debug( @@ -585,13 +589,17 @@ class ExplainingOutline(object): ) explanation = explanation.to(self.cfg.accelerator) else: - explanation = _get_explanation(self.explainer, item) - get_pred(self.explainer, explanation) - _save_explanation(explanation, path) - logging.debug( - "EXP::Generated; Path %s; SUCCEEDED", - (path), - ) + try: + explanation = _get_explanation(self.explainer, item) + get_pred(self.explainer, explanation) + _save_explanation(explanation, path) + logging.debug( + "EXP::Generated; Path %s; SUCCEEDED", + (path), + ) + except Exception as e: + logging.error(str(e)) + return None return explanation @@ -694,12 +702,15 @@ class ExplainingOutline(object): def get_attack(self, attack: Attack, item: Data, path: str): if is_exists(path): if self.explaining_cfg.explainer.force: - data_attack = attack.get_attacked_prediction(item) - logging.debug( - "ATTACK::Generated %s; Path %s; SUCCEEDED", - (path), - ) - + try: + data_attack = attack.get_attacked_prediction(item) + logging.debug( + "ATTACK::Generated %s; Path %s; SUCCEEDED", + (path), + ) + except Exception as e: + logging.error(str(e)) + return None else: data_attack = _load_explanation(path) logging.debug( @@ -707,12 +718,17 @@ class ExplainingOutline(object): (path), ) else: - data_attack = attack.get_attacked_prediction(item) - _save_explanation(data_attack, path) - logging.debug( - "ATTACK::Generated %s; Path %s; SUCCEEDED", - (path), - ) + try: + data_attack = attack.get_attacked_prediction(item) + _save_explanation(data_attack, path) + logging.debug( + "ATTACK::Generated %s; Path %s; SUCCEEDED", + (path), + ) + except Exception as e: + logging.error(str(e)) + return None + return data_attack def setup_experiment(self): diff --git a/main.py b/main.py index b361fe2..a8be4a9 100644 --- a/main.py +++ b/main.py @@ -40,6 +40,8 @@ if __name__ == "__main__": data_attack = outline.get_attack( attack=attack, item=item, path=data_attack_path ) + if data_attack is None: + continue item, index = outline.get_item() @@ -57,6 +59,8 @@ if __name__ == "__main__": attack_data = outline.get_attack( attack=attack, item=item, path=data_attack_path_ ) + if attack_data is None: + continue exp = outline.get_explanation(item=attack_data, path=data_attack_path_) pbar.update(1) if exp is None: