Fixings Attack pred score bug,it was a variable reference issue
This commit is contained in:
parent
3372f81576
commit
ad7fe83916
@ -2,15 +2,14 @@ import copy
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from explaining_framework.metric.base import Metric
|
||||
from explaining_framework.utils.io import obj_config_to_str
|
||||
from torch.nn import CrossEntropyLoss, MSELoss
|
||||
from torch_geometric.data import Batch, Data
|
||||
from torch_geometric.explain.explanation import Explanation
|
||||
from torch_geometric.graphgym.config import cfg
|
||||
from torch_geometric.utils import add_random_edge, dropout_edge, dropout_node
|
||||
|
||||
from explaining_framework.metric.base import Metric
|
||||
from explaining_framework.utils.io import obj_config_to_str
|
||||
|
||||
|
||||
def compute_gradient(model, inp, target, loss):
|
||||
with torch.autograd.set_grad_enabled(True):
|
||||
@ -136,7 +135,7 @@ class Attack(Metric):
|
||||
def get_attacked_prediction(self, data: Data) -> Data:
|
||||
data_ = data.clone()
|
||||
data_attacked = self.forward(data_)
|
||||
pred = self.get_prediction(x=data_.x, edge_index=data_.edge_index)
|
||||
pred = self.get_prediction(x=data.x, edge_index=data.edge_index)
|
||||
pred_attacked = self.get_prediction(
|
||||
x=data_attacked.x, edge_index=data_attacked.edge_index
|
||||
)
|
||||
|
2
main.py
2
main.py
@ -27,6 +27,7 @@ if __name__ == "__main__":
|
||||
args = parse_args()
|
||||
outline = ExplainingOutline(args.explaining_cfg_file, args.gpu_id)
|
||||
for attack in outline.attacks:
|
||||
logging.info(f"Running {attack.__class__.__name__}: {attack.name}")
|
||||
for item, index in tqdm(
|
||||
zip(outline.dataset, outline.indexes), total=len(outline.dataset)
|
||||
):
|
||||
@ -41,6 +42,7 @@ if __name__ == "__main__":
|
||||
)
|
||||
|
||||
for attack in outline.attacks:
|
||||
logging.info(f"Running {attack.__class__.__name__}: {attack.name}")
|
||||
for item, index in tqdm(
|
||||
zip(outline.dataset, outline.indexes), total=len(outline.dataset)
|
||||
):
|
||||
|
Loading…
Reference in New Issue
Block a user