Fixing many bugs
This commit is contained in:
parent
b7fb522255
commit
3d0d3ec451
|
@ -19,57 +19,57 @@ if "__main__" == __name__:
|
||||||
|
|
||||||
DATASET = [
|
DATASET = [
|
||||||
"CIFAR10",
|
"CIFAR10",
|
||||||
"TRIANGLES",
|
# "TRIANGLES",
|
||||||
"COLORS-3",
|
# "COLORS-3",
|
||||||
"REDDIT-BINARY",
|
# "REDDIT-BINARY",
|
||||||
"REDDIT-MULTI-5K",
|
# "REDDIT-MULTI-5K",
|
||||||
"REDDIT-MULTI-12K",
|
# "REDDIT-MULTI-12K",
|
||||||
"COLLAB",
|
# "COLLAB",
|
||||||
"DBLP_v1",
|
# "DBLP_v1",
|
||||||
"COIL-DEL",
|
# "COIL-DEL",
|
||||||
"COIL-RAG",
|
# "COIL-RAG",
|
||||||
"Fingerprint",
|
# "Fingerprint",
|
||||||
"Letter-high",
|
# "Letter-high",
|
||||||
"Letter-low",
|
# "Letter-low",
|
||||||
"Letter-med",
|
# "Letter-med",
|
||||||
"MSRC_9",
|
"MSRC_9",
|
||||||
"MSRC_21",
|
# "MSRC_21",
|
||||||
"MSRC_21C",
|
"MSRC_21C",
|
||||||
"DD",
|
# "DD",
|
||||||
"ENZYMES",
|
# "ENZYMES",
|
||||||
"PROTEINS",
|
"PROTEINS",
|
||||||
"QM9",
|
# "QM9",
|
||||||
"MUTAG",
|
# "MUTAG",
|
||||||
"Mutagenicity",
|
# "Mutagenicity",
|
||||||
"AIDS",
|
# "AIDS",
|
||||||
"PATTERN",
|
# "PATTERN",
|
||||||
"CLUSTER",
|
# "CLUSTER",
|
||||||
"MNIST",
|
"MNIST",
|
||||||
"CIFAR10",
|
"CIFAR10",
|
||||||
"TSP",
|
# "TSP",
|
||||||
"CSL",
|
# "CSL",
|
||||||
"KarateClub",
|
# "KarateClub",
|
||||||
"CS",
|
# "CS",
|
||||||
"Physics",
|
# "Physics",
|
||||||
"BBBP",
|
# "BBBP",
|
||||||
"Tox21",
|
# "Tox21",
|
||||||
"HIV",
|
# "HIV",
|
||||||
"PCBA",
|
# "PCBA",
|
||||||
"MUV",
|
# "MUV",
|
||||||
"BACE",
|
# "BACE",
|
||||||
"SIDER",
|
# "SIDER",
|
||||||
"ClinTox",
|
# "ClinTox",
|
||||||
"AIFB",
|
# "AIFB",
|
||||||
"AM",
|
# "AM",
|
||||||
"MUTAG",
|
# "MUTAG",
|
||||||
"BGS",
|
# "BGS",
|
||||||
"FAUST",
|
# "FAUST",
|
||||||
"DynamicFAUST",
|
# "DynamicFAUST",
|
||||||
"ShapeNet",
|
# "ShapeNet",
|
||||||
"ModelNet10",
|
# "ModelNet10",
|
||||||
"ModelNet40",
|
# "ModelNet40",
|
||||||
"PascalVOC-SP",
|
# "PascalVOC-SP",
|
||||||
"COCO-SP",
|
# "COCO-SP",
|
||||||
]
|
]
|
||||||
EXPLAINER = [
|
EXPLAINER = [
|
||||||
"CAM",
|
"CAM",
|
||||||
|
@ -81,8 +81,8 @@ if "__main__" == __name__:
|
||||||
# "PGExplainer",
|
# "PGExplainer",
|
||||||
"PGMExplainer",
|
"PGMExplainer",
|
||||||
"RandomExplainer",
|
"RandomExplainer",
|
||||||
"SubgraphX",
|
# "SubgraphX",
|
||||||
"GraphMASK",
|
# "GraphMASK",
|
||||||
"GNNExplainer",
|
"GNNExplainer",
|
||||||
"EIXGNN",
|
"EIXGNN",
|
||||||
"SCGNN",
|
"SCGNN",
|
||||||
|
@ -113,7 +113,9 @@ if "__main__" == __name__:
|
||||||
explaining_cfg["model"] = {}
|
explaining_cfg["model"] = {}
|
||||||
explaining_cfg["model"]["ckpt"] = string_to_python(model_kind)
|
explaining_cfg["model"]["ckpt"] = string_to_python(model_kind)
|
||||||
explaining_cfg["model"]["path"] = string_to_python(
|
explaining_cfg["model"]["path"] = string_to_python(
|
||||||
"/media/data/SIC/araison/exps/pyg_fork/graphgym/results/graph_classif_base_grid_graph_classif_grid"
|
# "/media/data/SIC/araison/exps/pyg_fork/graphgym/results/graph_classif_base_grid_graph_classif_grid"
|
||||||
|
"/home/SIC/araison/exps/pytorch_geometric/graphgym/results/"
|
||||||
|
# "/media/data/SIC/araison/exps/pyg_fork/graphgym/results/graph_classif_base_grid_graph_classif_grid"
|
||||||
)
|
)
|
||||||
# explaining_cfg['out_dir']='./explanation'
|
# explaining_cfg['out_dir']='./explanation'
|
||||||
# explaining_cfg['print']='both'
|
# explaining_cfg['print']='both'
|
||||||
|
|
|
@ -183,7 +183,6 @@ class ExplainingOutline(object):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def load_indexes(self):
|
def load_indexes(self):
|
||||||
|
|
||||||
item = self.explaining_cfg.dataset.item
|
item = self.explaining_cfg.dataset.item
|
||||||
if isinstance(item, (list, int)):
|
if isinstance(item, (list, int)):
|
||||||
indexes = item
|
indexes = item
|
||||||
|
@ -566,17 +565,22 @@ class ExplainingOutline(object):
|
||||||
def get_explanation(self, item: Data, path: str):
|
def get_explanation(self, item: Data, path: str):
|
||||||
if is_exists(path):
|
if is_exists(path):
|
||||||
if self.explaining_cfg.explainer.force:
|
if self.explaining_cfg.explainer.force:
|
||||||
explanation = _get_explanation(self.explainer, item)
|
try:
|
||||||
if explanation is None:
|
explanation = _get_explanation(self.explainer, item)
|
||||||
logging.warning(
|
if explanation is None:
|
||||||
" EXP::Generated; Path %s; FAILED",
|
logging.error(
|
||||||
(path),
|
" EXP::Generated; Path %s; FAILED",
|
||||||
)
|
(path),
|
||||||
else:
|
)
|
||||||
logging.debug(
|
|
||||||
"EXP::Generated; Path %s; SUCCEEDED",
|
else:
|
||||||
(path),
|
logging.debug(
|
||||||
)
|
"EXP::Generated; Path %s; SUCCEEDED",
|
||||||
|
(path),
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(str(e))
|
||||||
|
return None
|
||||||
else:
|
else:
|
||||||
explanation = _load_explanation(path)
|
explanation = _load_explanation(path)
|
||||||
logging.debug(
|
logging.debug(
|
||||||
|
@ -585,13 +589,17 @@ class ExplainingOutline(object):
|
||||||
)
|
)
|
||||||
explanation = explanation.to(self.cfg.accelerator)
|
explanation = explanation.to(self.cfg.accelerator)
|
||||||
else:
|
else:
|
||||||
explanation = _get_explanation(self.explainer, item)
|
try:
|
||||||
get_pred(self.explainer, explanation)
|
explanation = _get_explanation(self.explainer, item)
|
||||||
_save_explanation(explanation, path)
|
get_pred(self.explainer, explanation)
|
||||||
logging.debug(
|
_save_explanation(explanation, path)
|
||||||
"EXP::Generated; Path %s; SUCCEEDED",
|
logging.debug(
|
||||||
(path),
|
"EXP::Generated; Path %s; SUCCEEDED",
|
||||||
)
|
(path),
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(str(e))
|
||||||
|
return None
|
||||||
|
|
||||||
return explanation
|
return explanation
|
||||||
|
|
||||||
|
@ -694,12 +702,15 @@ class ExplainingOutline(object):
|
||||||
def get_attack(self, attack: Attack, item: Data, path: str):
|
def get_attack(self, attack: Attack, item: Data, path: str):
|
||||||
if is_exists(path):
|
if is_exists(path):
|
||||||
if self.explaining_cfg.explainer.force:
|
if self.explaining_cfg.explainer.force:
|
||||||
data_attack = attack.get_attacked_prediction(item)
|
try:
|
||||||
logging.debug(
|
data_attack = attack.get_attacked_prediction(item)
|
||||||
"ATTACK::Generated %s; Path %s; SUCCEEDED",
|
logging.debug(
|
||||||
(path),
|
"ATTACK::Generated %s; Path %s; SUCCEEDED",
|
||||||
)
|
(path),
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(str(e))
|
||||||
|
return None
|
||||||
else:
|
else:
|
||||||
data_attack = _load_explanation(path)
|
data_attack = _load_explanation(path)
|
||||||
logging.debug(
|
logging.debug(
|
||||||
|
@ -707,12 +718,17 @@ class ExplainingOutline(object):
|
||||||
(path),
|
(path),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
data_attack = attack.get_attacked_prediction(item)
|
try:
|
||||||
_save_explanation(data_attack, path)
|
data_attack = attack.get_attacked_prediction(item)
|
||||||
logging.debug(
|
_save_explanation(data_attack, path)
|
||||||
"ATTACK::Generated %s; Path %s; SUCCEEDED",
|
logging.debug(
|
||||||
(path),
|
"ATTACK::Generated %s; Path %s; SUCCEEDED",
|
||||||
)
|
(path),
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(str(e))
|
||||||
|
return None
|
||||||
|
|
||||||
return data_attack
|
return data_attack
|
||||||
|
|
||||||
def setup_experiment(self):
|
def setup_experiment(self):
|
||||||
|
|
4
main.py
4
main.py
|
@ -40,6 +40,8 @@ if __name__ == "__main__":
|
||||||
data_attack = outline.get_attack(
|
data_attack = outline.get_attack(
|
||||||
attack=attack, item=item, path=data_attack_path
|
attack=attack, item=item, path=data_attack_path
|
||||||
)
|
)
|
||||||
|
if data_attack is None:
|
||||||
|
continue
|
||||||
|
|
||||||
item, index = outline.get_item()
|
item, index = outline.get_item()
|
||||||
|
|
||||||
|
@ -57,6 +59,8 @@ if __name__ == "__main__":
|
||||||
attack_data = outline.get_attack(
|
attack_data = outline.get_attack(
|
||||||
attack=attack, item=item, path=data_attack_path_
|
attack=attack, item=item, path=data_attack_path_
|
||||||
)
|
)
|
||||||
|
if attack_data is None:
|
||||||
|
continue
|
||||||
exp = outline.get_explanation(item=attack_data, path=data_attack_path_)
|
exp = outline.get_explanation(item=attack_data, path=data_attack_path_)
|
||||||
pbar.update(1)
|
pbar.update(1)
|
||||||
if exp is None:
|
if exp is None:
|
||||||
|
|
Loading…
Reference in New Issue