Fixing many bugs

This commit is contained in:
araison 2023-01-10 18:49:38 +01:00
parent b7fb522255
commit 3d0d3ec451
3 changed files with 102 additions and 80 deletions

View File

@ -19,57 +19,57 @@ if "__main__" == __name__:
DATASET = [
"CIFAR10",
"TRIANGLES",
"COLORS-3",
"REDDIT-BINARY",
"REDDIT-MULTI-5K",
"REDDIT-MULTI-12K",
"COLLAB",
"DBLP_v1",
"COIL-DEL",
"COIL-RAG",
"Fingerprint",
"Letter-high",
"Letter-low",
"Letter-med",
# "TRIANGLES",
# "COLORS-3",
# "REDDIT-BINARY",
# "REDDIT-MULTI-5K",
# "REDDIT-MULTI-12K",
# "COLLAB",
# "DBLP_v1",
# "COIL-DEL",
# "COIL-RAG",
# "Fingerprint",
# "Letter-high",
# "Letter-low",
# "Letter-med",
"MSRC_9",
"MSRC_21",
# "MSRC_21",
"MSRC_21C",
"DD",
"ENZYMES",
# "DD",
# "ENZYMES",
"PROTEINS",
"QM9",
"MUTAG",
"Mutagenicity",
"AIDS",
"PATTERN",
"CLUSTER",
# "QM9",
# "MUTAG",
# "Mutagenicity",
# "AIDS",
# "PATTERN",
# "CLUSTER",
"MNIST",
"CIFAR10",
"TSP",
"CSL",
"KarateClub",
"CS",
"Physics",
"BBBP",
"Tox21",
"HIV",
"PCBA",
"MUV",
"BACE",
"SIDER",
"ClinTox",
"AIFB",
"AM",
"MUTAG",
"BGS",
"FAUST",
"DynamicFAUST",
"ShapeNet",
"ModelNet10",
"ModelNet40",
"PascalVOC-SP",
"COCO-SP",
# "TSP",
# "CSL",
# "KarateClub",
# "CS",
# "Physics",
# "BBBP",
# "Tox21",
# "HIV",
# "PCBA",
# "MUV",
# "BACE",
# "SIDER",
# "ClinTox",
# "AIFB",
# "AM",
# "MUTAG",
# "BGS",
# "FAUST",
# "DynamicFAUST",
# "ShapeNet",
# "ModelNet10",
# "ModelNet40",
# "PascalVOC-SP",
# "COCO-SP",
]
EXPLAINER = [
"CAM",
@ -81,8 +81,8 @@ if "__main__" == __name__:
# "PGExplainer",
"PGMExplainer",
"RandomExplainer",
"SubgraphX",
"GraphMASK",
# "SubgraphX",
# "GraphMASK",
"GNNExplainer",
"EIXGNN",
"SCGNN",
@ -113,7 +113,9 @@ if "__main__" == __name__:
explaining_cfg["model"] = {}
explaining_cfg["model"]["ckpt"] = string_to_python(model_kind)
explaining_cfg["model"]["path"] = string_to_python(
"/media/data/SIC/araison/exps/pyg_fork/graphgym/results/graph_classif_base_grid_graph_classif_grid"
# "/media/data/SIC/araison/exps/pyg_fork/graphgym/results/graph_classif_base_grid_graph_classif_grid"
"/home/SIC/araison/exps/pytorch_geometric/graphgym/results/"
# "/media/data/SIC/araison/exps/pyg_fork/graphgym/results/graph_classif_base_grid_graph_classif_grid"
)
# explaining_cfg['out_dir']='./explanation'
# explaining_cfg['print']='both'

View File

@ -183,7 +183,6 @@ class ExplainingOutline(object):
return None
def load_indexes(self):
item = self.explaining_cfg.dataset.item
if isinstance(item, (list, int)):
indexes = item
@ -566,17 +565,22 @@ class ExplainingOutline(object):
def get_explanation(self, item: Data, path: str):
if is_exists(path):
if self.explaining_cfg.explainer.force:
try:
explanation = _get_explanation(self.explainer, item)
if explanation is None:
logging.warning(
logging.error(
" EXP::Generated; Path %s; FAILED",
(path),
)
else:
logging.debug(
"EXP::Generated; Path %s; SUCCEEDED",
(path),
)
except Exception as e:
logging.error(str(e))
return None
else:
explanation = _load_explanation(path)
logging.debug(
@ -585,6 +589,7 @@ class ExplainingOutline(object):
)
explanation = explanation.to(self.cfg.accelerator)
else:
try:
explanation = _get_explanation(self.explainer, item)
get_pred(self.explainer, explanation)
_save_explanation(explanation, path)
@ -592,6 +597,9 @@ class ExplainingOutline(object):
"EXP::Generated; Path %s; SUCCEEDED",
(path),
)
except Exception as e:
logging.error(str(e))
return None
return explanation
@ -694,12 +702,15 @@ class ExplainingOutline(object):
def get_attack(self, attack: Attack, item: Data, path: str):
if is_exists(path):
if self.explaining_cfg.explainer.force:
try:
data_attack = attack.get_attacked_prediction(item)
logging.debug(
"ATTACK::Generated %s; Path %s; SUCCEEDED",
(path),
)
except Exception as e:
logging.error(str(e))
return None
else:
data_attack = _load_explanation(path)
logging.debug(
@ -707,12 +718,17 @@ class ExplainingOutline(object):
(path),
)
else:
try:
data_attack = attack.get_attacked_prediction(item)
_save_explanation(data_attack, path)
logging.debug(
"ATTACK::Generated %s; Path %s; SUCCEEDED",
(path),
)
except Exception as e:
logging.error(str(e))
return None
return data_attack
def setup_experiment(self):

View File

@ -40,6 +40,8 @@ if __name__ == "__main__":
data_attack = outline.get_attack(
attack=attack, item=item, path=data_attack_path
)
if data_attack is None:
continue
item, index = outline.get_item()
@ -57,6 +59,8 @@ if __name__ == "__main__":
attack_data = outline.get_attack(
attack=attack, item=item, path=data_attack_path_
)
if attack_data is None:
continue
exp = outline.get_explanation(item=attack_data, path=data_attack_path_)
pbar.update(1)
if exp is None: