diff --git a/explaining_framework/metric/robust.py b/explaining_framework/metric/robust.py index f0795ab..c40509f 100644 --- a/explaining_framework/metric/robust.py +++ b/explaining_framework/metric/robust.py @@ -2,15 +2,14 @@ import copy import torch import torch.nn.functional as F +from explaining_framework.metric.base import Metric +from explaining_framework.utils.io import obj_config_to_str from torch.nn import CrossEntropyLoss, MSELoss from torch_geometric.data import Batch, Data from torch_geometric.explain.explanation import Explanation from torch_geometric.graphgym.config import cfg from torch_geometric.utils import add_random_edge, dropout_edge, dropout_node -from explaining_framework.metric.base import Metric -from explaining_framework.utils.io import obj_config_to_str - def compute_gradient(model, inp, target, loss): with torch.autograd.set_grad_enabled(True): @@ -136,7 +135,7 @@ class Attack(Metric): def get_attacked_prediction(self, data: Data) -> Data: data_ = data.clone() data_attacked = self.forward(data_) - pred = self.get_prediction(x=data_.x, edge_index=data_.edge_index) + pred = self.get_prediction(x=data.x, edge_index=data.edge_index) pred_attacked = self.get_prediction( x=data_attacked.x, edge_index=data_attacked.edge_index ) diff --git a/main.py b/main.py index 33c6a56..6e40c1e 100644 --- a/main.py +++ b/main.py @@ -27,6 +27,7 @@ if __name__ == "__main__": args = parse_args() outline = ExplainingOutline(args.explaining_cfg_file, args.gpu_id) for attack in outline.attacks: + logging.info(f"Running {attack.__class__.__name__}: {attack.name}") for item, index in tqdm( zip(outline.dataset, outline.indexes), total=len(outline.dataset) ): @@ -41,6 +42,7 @@ if __name__ == "__main__": ) for attack in outline.attacks: + logging.info(f"Running {attack.__class__.__name__}: {attack.name}") for item, index in tqdm( zip(outline.dataset, outline.indexes), total=len(outline.dataset) ):