Fixings Attack pred score bug,it was a variable reference issue

This commit is contained in:
araison 2023-01-11 20:46:24 +01:00
parent 3372f81576
commit ad7fe83916
2 changed files with 5 additions and 4 deletions

View File

@ -2,15 +2,14 @@ import copy
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from explaining_framework.metric.base import Metric
from explaining_framework.utils.io import obj_config_to_str
from torch.nn import CrossEntropyLoss, MSELoss from torch.nn import CrossEntropyLoss, MSELoss
from torch_geometric.data import Batch, Data from torch_geometric.data import Batch, Data
from torch_geometric.explain.explanation import Explanation from torch_geometric.explain.explanation import Explanation
from torch_geometric.graphgym.config import cfg from torch_geometric.graphgym.config import cfg
from torch_geometric.utils import add_random_edge, dropout_edge, dropout_node from torch_geometric.utils import add_random_edge, dropout_edge, dropout_node
from explaining_framework.metric.base import Metric
from explaining_framework.utils.io import obj_config_to_str
def compute_gradient(model, inp, target, loss): def compute_gradient(model, inp, target, loss):
with torch.autograd.set_grad_enabled(True): with torch.autograd.set_grad_enabled(True):
@ -136,7 +135,7 @@ class Attack(Metric):
def get_attacked_prediction(self, data: Data) -> Data: def get_attacked_prediction(self, data: Data) -> Data:
data_ = data.clone() data_ = data.clone()
data_attacked = self.forward(data_) data_attacked = self.forward(data_)
pred = self.get_prediction(x=data_.x, edge_index=data_.edge_index) pred = self.get_prediction(x=data.x, edge_index=data.edge_index)
pred_attacked = self.get_prediction( pred_attacked = self.get_prediction(
x=data_attacked.x, edge_index=data_attacked.edge_index x=data_attacked.x, edge_index=data_attacked.edge_index
) )

View File

@ -27,6 +27,7 @@ if __name__ == "__main__":
args = parse_args() args = parse_args()
outline = ExplainingOutline(args.explaining_cfg_file, args.gpu_id) outline = ExplainingOutline(args.explaining_cfg_file, args.gpu_id)
for attack in outline.attacks: for attack in outline.attacks:
logging.info(f"Running {attack.__class__.__name__}: {attack.name}")
for item, index in tqdm( for item, index in tqdm(
zip(outline.dataset, outline.indexes), total=len(outline.dataset) zip(outline.dataset, outline.indexes), total=len(outline.dataset)
): ):
@ -41,6 +42,7 @@ if __name__ == "__main__":
) )
for attack in outline.attacks: for attack in outline.attacks:
logging.info(f"Running {attack.__class__.__name__}: {attack.name}")
for item, index in tqdm( for item, index in tqdm(
zip(outline.dataset, outline.indexes), total=len(outline.dataset) zip(outline.dataset, outline.indexes), total=len(outline.dataset)
): ):