Fixing many bugs
This commit is contained in:
parent
b7fb522255
commit
3d0d3ec451
@ -19,57 +19,57 @@ if "__main__" == __name__:
|
||||
|
||||
DATASET = [
|
||||
"CIFAR10",
|
||||
"TRIANGLES",
|
||||
"COLORS-3",
|
||||
"REDDIT-BINARY",
|
||||
"REDDIT-MULTI-5K",
|
||||
"REDDIT-MULTI-12K",
|
||||
"COLLAB",
|
||||
"DBLP_v1",
|
||||
"COIL-DEL",
|
||||
"COIL-RAG",
|
||||
"Fingerprint",
|
||||
"Letter-high",
|
||||
"Letter-low",
|
||||
"Letter-med",
|
||||
# "TRIANGLES",
|
||||
# "COLORS-3",
|
||||
# "REDDIT-BINARY",
|
||||
# "REDDIT-MULTI-5K",
|
||||
# "REDDIT-MULTI-12K",
|
||||
# "COLLAB",
|
||||
# "DBLP_v1",
|
||||
# "COIL-DEL",
|
||||
# "COIL-RAG",
|
||||
# "Fingerprint",
|
||||
# "Letter-high",
|
||||
# "Letter-low",
|
||||
# "Letter-med",
|
||||
"MSRC_9",
|
||||
"MSRC_21",
|
||||
# "MSRC_21",
|
||||
"MSRC_21C",
|
||||
"DD",
|
||||
"ENZYMES",
|
||||
# "DD",
|
||||
# "ENZYMES",
|
||||
"PROTEINS",
|
||||
"QM9",
|
||||
"MUTAG",
|
||||
"Mutagenicity",
|
||||
"AIDS",
|
||||
"PATTERN",
|
||||
"CLUSTER",
|
||||
# "QM9",
|
||||
# "MUTAG",
|
||||
# "Mutagenicity",
|
||||
# "AIDS",
|
||||
# "PATTERN",
|
||||
# "CLUSTER",
|
||||
"MNIST",
|
||||
"CIFAR10",
|
||||
"TSP",
|
||||
"CSL",
|
||||
"KarateClub",
|
||||
"CS",
|
||||
"Physics",
|
||||
"BBBP",
|
||||
"Tox21",
|
||||
"HIV",
|
||||
"PCBA",
|
||||
"MUV",
|
||||
"BACE",
|
||||
"SIDER",
|
||||
"ClinTox",
|
||||
"AIFB",
|
||||
"AM",
|
||||
"MUTAG",
|
||||
"BGS",
|
||||
"FAUST",
|
||||
"DynamicFAUST",
|
||||
"ShapeNet",
|
||||
"ModelNet10",
|
||||
"ModelNet40",
|
||||
"PascalVOC-SP",
|
||||
"COCO-SP",
|
||||
# "TSP",
|
||||
# "CSL",
|
||||
# "KarateClub",
|
||||
# "CS",
|
||||
# "Physics",
|
||||
# "BBBP",
|
||||
# "Tox21",
|
||||
# "HIV",
|
||||
# "PCBA",
|
||||
# "MUV",
|
||||
# "BACE",
|
||||
# "SIDER",
|
||||
# "ClinTox",
|
||||
# "AIFB",
|
||||
# "AM",
|
||||
# "MUTAG",
|
||||
# "BGS",
|
||||
# "FAUST",
|
||||
# "DynamicFAUST",
|
||||
# "ShapeNet",
|
||||
# "ModelNet10",
|
||||
# "ModelNet40",
|
||||
# "PascalVOC-SP",
|
||||
# "COCO-SP",
|
||||
]
|
||||
EXPLAINER = [
|
||||
"CAM",
|
||||
@ -81,8 +81,8 @@ if "__main__" == __name__:
|
||||
# "PGExplainer",
|
||||
"PGMExplainer",
|
||||
"RandomExplainer",
|
||||
"SubgraphX",
|
||||
"GraphMASK",
|
||||
# "SubgraphX",
|
||||
# "GraphMASK",
|
||||
"GNNExplainer",
|
||||
"EIXGNN",
|
||||
"SCGNN",
|
||||
@ -113,7 +113,9 @@ if "__main__" == __name__:
|
||||
explaining_cfg["model"] = {}
|
||||
explaining_cfg["model"]["ckpt"] = string_to_python(model_kind)
|
||||
explaining_cfg["model"]["path"] = string_to_python(
|
||||
"/media/data/SIC/araison/exps/pyg_fork/graphgym/results/graph_classif_base_grid_graph_classif_grid"
|
||||
# "/media/data/SIC/araison/exps/pyg_fork/graphgym/results/graph_classif_base_grid_graph_classif_grid"
|
||||
"/home/SIC/araison/exps/pytorch_geometric/graphgym/results/"
|
||||
# "/media/data/SIC/araison/exps/pyg_fork/graphgym/results/graph_classif_base_grid_graph_classif_grid"
|
||||
)
|
||||
# explaining_cfg['out_dir']='./explanation'
|
||||
# explaining_cfg['print']='both'
|
||||
|
@ -183,7 +183,6 @@ class ExplainingOutline(object):
|
||||
return None
|
||||
|
||||
def load_indexes(self):
|
||||
|
||||
item = self.explaining_cfg.dataset.item
|
||||
if isinstance(item, (list, int)):
|
||||
indexes = item
|
||||
@ -566,17 +565,22 @@ class ExplainingOutline(object):
|
||||
def get_explanation(self, item: Data, path: str):
|
||||
if is_exists(path):
|
||||
if self.explaining_cfg.explainer.force:
|
||||
explanation = _get_explanation(self.explainer, item)
|
||||
if explanation is None:
|
||||
logging.warning(
|
||||
" EXP::Generated; Path %s; FAILED",
|
||||
(path),
|
||||
)
|
||||
else:
|
||||
logging.debug(
|
||||
"EXP::Generated; Path %s; SUCCEEDED",
|
||||
(path),
|
||||
)
|
||||
try:
|
||||
explanation = _get_explanation(self.explainer, item)
|
||||
if explanation is None:
|
||||
logging.error(
|
||||
" EXP::Generated; Path %s; FAILED",
|
||||
(path),
|
||||
)
|
||||
|
||||
else:
|
||||
logging.debug(
|
||||
"EXP::Generated; Path %s; SUCCEEDED",
|
||||
(path),
|
||||
)
|
||||
except Exception as e:
|
||||
logging.error(str(e))
|
||||
return None
|
||||
else:
|
||||
explanation = _load_explanation(path)
|
||||
logging.debug(
|
||||
@ -585,13 +589,17 @@ class ExplainingOutline(object):
|
||||
)
|
||||
explanation = explanation.to(self.cfg.accelerator)
|
||||
else:
|
||||
explanation = _get_explanation(self.explainer, item)
|
||||
get_pred(self.explainer, explanation)
|
||||
_save_explanation(explanation, path)
|
||||
logging.debug(
|
||||
"EXP::Generated; Path %s; SUCCEEDED",
|
||||
(path),
|
||||
)
|
||||
try:
|
||||
explanation = _get_explanation(self.explainer, item)
|
||||
get_pred(self.explainer, explanation)
|
||||
_save_explanation(explanation, path)
|
||||
logging.debug(
|
||||
"EXP::Generated; Path %s; SUCCEEDED",
|
||||
(path),
|
||||
)
|
||||
except Exception as e:
|
||||
logging.error(str(e))
|
||||
return None
|
||||
|
||||
return explanation
|
||||
|
||||
@ -694,12 +702,15 @@ class ExplainingOutline(object):
|
||||
def get_attack(self, attack: Attack, item: Data, path: str):
|
||||
if is_exists(path):
|
||||
if self.explaining_cfg.explainer.force:
|
||||
data_attack = attack.get_attacked_prediction(item)
|
||||
logging.debug(
|
||||
"ATTACK::Generated %s; Path %s; SUCCEEDED",
|
||||
(path),
|
||||
)
|
||||
|
||||
try:
|
||||
data_attack = attack.get_attacked_prediction(item)
|
||||
logging.debug(
|
||||
"ATTACK::Generated %s; Path %s; SUCCEEDED",
|
||||
(path),
|
||||
)
|
||||
except Exception as e:
|
||||
logging.error(str(e))
|
||||
return None
|
||||
else:
|
||||
data_attack = _load_explanation(path)
|
||||
logging.debug(
|
||||
@ -707,12 +718,17 @@ class ExplainingOutline(object):
|
||||
(path),
|
||||
)
|
||||
else:
|
||||
data_attack = attack.get_attacked_prediction(item)
|
||||
_save_explanation(data_attack, path)
|
||||
logging.debug(
|
||||
"ATTACK::Generated %s; Path %s; SUCCEEDED",
|
||||
(path),
|
||||
)
|
||||
try:
|
||||
data_attack = attack.get_attacked_prediction(item)
|
||||
_save_explanation(data_attack, path)
|
||||
logging.debug(
|
||||
"ATTACK::Generated %s; Path %s; SUCCEEDED",
|
||||
(path),
|
||||
)
|
||||
except Exception as e:
|
||||
logging.error(str(e))
|
||||
return None
|
||||
|
||||
return data_attack
|
||||
|
||||
def setup_experiment(self):
|
||||
|
4
main.py
4
main.py
@ -40,6 +40,8 @@ if __name__ == "__main__":
|
||||
data_attack = outline.get_attack(
|
||||
attack=attack, item=item, path=data_attack_path
|
||||
)
|
||||
if data_attack is None:
|
||||
continue
|
||||
|
||||
item, index = outline.get_item()
|
||||
|
||||
@ -57,6 +59,8 @@ if __name__ == "__main__":
|
||||
attack_data = outline.get_attack(
|
||||
attack=attack, item=item, path=data_attack_path_
|
||||
)
|
||||
if attack_data is None:
|
||||
continue
|
||||
exp = outline.get_explanation(item=attack_data, path=data_attack_path_)
|
||||
pbar.update(1)
|
||||
if exp is None:
|
||||
|
Loading…
Reference in New Issue
Block a user