Fixing many bugs

This commit is contained in:
araison 2023-01-10 18:49:38 +01:00
parent b7fb522255
commit 3d0d3ec451
3 changed files with 102 additions and 80 deletions

View File

@ -19,57 +19,57 @@ if "__main__" == __name__:
DATASET = [ DATASET = [
"CIFAR10", "CIFAR10",
"TRIANGLES", # "TRIANGLES",
"COLORS-3", # "COLORS-3",
"REDDIT-BINARY", # "REDDIT-BINARY",
"REDDIT-MULTI-5K", # "REDDIT-MULTI-5K",
"REDDIT-MULTI-12K", # "REDDIT-MULTI-12K",
"COLLAB", # "COLLAB",
"DBLP_v1", # "DBLP_v1",
"COIL-DEL", # "COIL-DEL",
"COIL-RAG", # "COIL-RAG",
"Fingerprint", # "Fingerprint",
"Letter-high", # "Letter-high",
"Letter-low", # "Letter-low",
"Letter-med", # "Letter-med",
"MSRC_9", "MSRC_9",
"MSRC_21", # "MSRC_21",
"MSRC_21C", "MSRC_21C",
"DD", # "DD",
"ENZYMES", # "ENZYMES",
"PROTEINS", "PROTEINS",
"QM9", # "QM9",
"MUTAG", # "MUTAG",
"Mutagenicity", # "Mutagenicity",
"AIDS", # "AIDS",
"PATTERN", # "PATTERN",
"CLUSTER", # "CLUSTER",
"MNIST", "MNIST",
"CIFAR10", "CIFAR10",
"TSP", # "TSP",
"CSL", # "CSL",
"KarateClub", # "KarateClub",
"CS", # "CS",
"Physics", # "Physics",
"BBBP", # "BBBP",
"Tox21", # "Tox21",
"HIV", # "HIV",
"PCBA", # "PCBA",
"MUV", # "MUV",
"BACE", # "BACE",
"SIDER", # "SIDER",
"ClinTox", # "ClinTox",
"AIFB", # "AIFB",
"AM", # "AM",
"MUTAG", # "MUTAG",
"BGS", # "BGS",
"FAUST", # "FAUST",
"DynamicFAUST", # "DynamicFAUST",
"ShapeNet", # "ShapeNet",
"ModelNet10", # "ModelNet10",
"ModelNet40", # "ModelNet40",
"PascalVOC-SP", # "PascalVOC-SP",
"COCO-SP", # "COCO-SP",
] ]
EXPLAINER = [ EXPLAINER = [
"CAM", "CAM",
@ -81,8 +81,8 @@ if "__main__" == __name__:
# "PGExplainer", # "PGExplainer",
"PGMExplainer", "PGMExplainer",
"RandomExplainer", "RandomExplainer",
"SubgraphX", # "SubgraphX",
"GraphMASK", # "GraphMASK",
"GNNExplainer", "GNNExplainer",
"EIXGNN", "EIXGNN",
"SCGNN", "SCGNN",
@ -113,7 +113,9 @@ if "__main__" == __name__:
explaining_cfg["model"] = {} explaining_cfg["model"] = {}
explaining_cfg["model"]["ckpt"] = string_to_python(model_kind) explaining_cfg["model"]["ckpt"] = string_to_python(model_kind)
explaining_cfg["model"]["path"] = string_to_python( explaining_cfg["model"]["path"] = string_to_python(
"/media/data/SIC/araison/exps/pyg_fork/graphgym/results/graph_classif_base_grid_graph_classif_grid" # "/media/data/SIC/araison/exps/pyg_fork/graphgym/results/graph_classif_base_grid_graph_classif_grid"
"/home/SIC/araison/exps/pytorch_geometric/graphgym/results/"
# "/media/data/SIC/araison/exps/pyg_fork/graphgym/results/graph_classif_base_grid_graph_classif_grid"
) )
# explaining_cfg['out_dir']='./explanation' # explaining_cfg['out_dir']='./explanation'
# explaining_cfg['print']='both' # explaining_cfg['print']='both'

View File

@ -183,7 +183,6 @@ class ExplainingOutline(object):
return None return None
def load_indexes(self): def load_indexes(self):
item = self.explaining_cfg.dataset.item item = self.explaining_cfg.dataset.item
if isinstance(item, (list, int)): if isinstance(item, (list, int)):
indexes = item indexes = item
@ -566,17 +565,22 @@ class ExplainingOutline(object):
def get_explanation(self, item: Data, path: str): def get_explanation(self, item: Data, path: str):
if is_exists(path): if is_exists(path):
if self.explaining_cfg.explainer.force: if self.explaining_cfg.explainer.force:
explanation = _get_explanation(self.explainer, item) try:
if explanation is None: explanation = _get_explanation(self.explainer, item)
logging.warning( if explanation is None:
" EXP::Generated; Path %s; FAILED", logging.error(
(path), " EXP::Generated; Path %s; FAILED",
) (path),
else: )
logging.debug(
"EXP::Generated; Path %s; SUCCEEDED", else:
(path), logging.debug(
) "EXP::Generated; Path %s; SUCCEEDED",
(path),
)
except Exception as e:
logging.error(str(e))
return None
else: else:
explanation = _load_explanation(path) explanation = _load_explanation(path)
logging.debug( logging.debug(
@ -585,13 +589,17 @@ class ExplainingOutline(object):
) )
explanation = explanation.to(self.cfg.accelerator) explanation = explanation.to(self.cfg.accelerator)
else: else:
explanation = _get_explanation(self.explainer, item) try:
get_pred(self.explainer, explanation) explanation = _get_explanation(self.explainer, item)
_save_explanation(explanation, path) get_pred(self.explainer, explanation)
logging.debug( _save_explanation(explanation, path)
"EXP::Generated; Path %s; SUCCEEDED", logging.debug(
(path), "EXP::Generated; Path %s; SUCCEEDED",
) (path),
)
except Exception as e:
logging.error(str(e))
return None
return explanation return explanation
@ -694,12 +702,15 @@ class ExplainingOutline(object):
def get_attack(self, attack: Attack, item: Data, path: str): def get_attack(self, attack: Attack, item: Data, path: str):
if is_exists(path): if is_exists(path):
if self.explaining_cfg.explainer.force: if self.explaining_cfg.explainer.force:
data_attack = attack.get_attacked_prediction(item) try:
logging.debug( data_attack = attack.get_attacked_prediction(item)
"ATTACK::Generated %s; Path %s; SUCCEEDED", logging.debug(
(path), "ATTACK::Generated %s; Path %s; SUCCEEDED",
) (path),
)
except Exception as e:
logging.error(str(e))
return None
else: else:
data_attack = _load_explanation(path) data_attack = _load_explanation(path)
logging.debug( logging.debug(
@ -707,12 +718,17 @@ class ExplainingOutline(object):
(path), (path),
) )
else: else:
data_attack = attack.get_attacked_prediction(item) try:
_save_explanation(data_attack, path) data_attack = attack.get_attacked_prediction(item)
logging.debug( _save_explanation(data_attack, path)
"ATTACK::Generated %s; Path %s; SUCCEEDED", logging.debug(
(path), "ATTACK::Generated %s; Path %s; SUCCEEDED",
) (path),
)
except Exception as e:
logging.error(str(e))
return None
return data_attack return data_attack
def setup_experiment(self): def setup_experiment(self):

View File

@ -40,6 +40,8 @@ if __name__ == "__main__":
data_attack = outline.get_attack( data_attack = outline.get_attack(
attack=attack, item=item, path=data_attack_path attack=attack, item=item, path=data_attack_path
) )
if data_attack is None:
continue
item, index = outline.get_item() item, index = outline.get_item()
@ -57,6 +59,8 @@ if __name__ == "__main__":
attack_data = outline.get_attack( attack_data = outline.get_attack(
attack=attack, item=item, path=data_attack_path_ attack=attack, item=item, path=data_attack_path_
) )
if attack_data is None:
continue
exp = outline.get_explanation(item=attack_data, path=data_attack_path_) exp = outline.get_explanation(item=attack_data, path=data_attack_path_)
pbar.update(1) pbar.update(1)
if exp is None: if exp is None: