Fixings Attack pred score bug,it was a variable reference issue
This commit is contained in:
parent
3372f81576
commit
ad7fe83916
|
@ -2,15 +2,14 @@ import copy
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
from explaining_framework.metric.base import Metric
|
||||||
|
from explaining_framework.utils.io import obj_config_to_str
|
||||||
from torch.nn import CrossEntropyLoss, MSELoss
|
from torch.nn import CrossEntropyLoss, MSELoss
|
||||||
from torch_geometric.data import Batch, Data
|
from torch_geometric.data import Batch, Data
|
||||||
from torch_geometric.explain.explanation import Explanation
|
from torch_geometric.explain.explanation import Explanation
|
||||||
from torch_geometric.graphgym.config import cfg
|
from torch_geometric.graphgym.config import cfg
|
||||||
from torch_geometric.utils import add_random_edge, dropout_edge, dropout_node
|
from torch_geometric.utils import add_random_edge, dropout_edge, dropout_node
|
||||||
|
|
||||||
from explaining_framework.metric.base import Metric
|
|
||||||
from explaining_framework.utils.io import obj_config_to_str
|
|
||||||
|
|
||||||
|
|
||||||
def compute_gradient(model, inp, target, loss):
|
def compute_gradient(model, inp, target, loss):
|
||||||
with torch.autograd.set_grad_enabled(True):
|
with torch.autograd.set_grad_enabled(True):
|
||||||
|
@ -136,7 +135,7 @@ class Attack(Metric):
|
||||||
def get_attacked_prediction(self, data: Data) -> Data:
|
def get_attacked_prediction(self, data: Data) -> Data:
|
||||||
data_ = data.clone()
|
data_ = data.clone()
|
||||||
data_attacked = self.forward(data_)
|
data_attacked = self.forward(data_)
|
||||||
pred = self.get_prediction(x=data_.x, edge_index=data_.edge_index)
|
pred = self.get_prediction(x=data.x, edge_index=data.edge_index)
|
||||||
pred_attacked = self.get_prediction(
|
pred_attacked = self.get_prediction(
|
||||||
x=data_attacked.x, edge_index=data_attacked.edge_index
|
x=data_attacked.x, edge_index=data_attacked.edge_index
|
||||||
)
|
)
|
||||||
|
|
2
main.py
2
main.py
|
@ -27,6 +27,7 @@ if __name__ == "__main__":
|
||||||
args = parse_args()
|
args = parse_args()
|
||||||
outline = ExplainingOutline(args.explaining_cfg_file, args.gpu_id)
|
outline = ExplainingOutline(args.explaining_cfg_file, args.gpu_id)
|
||||||
for attack in outline.attacks:
|
for attack in outline.attacks:
|
||||||
|
logging.info(f"Running {attack.__class__.__name__}: {attack.name}")
|
||||||
for item, index in tqdm(
|
for item, index in tqdm(
|
||||||
zip(outline.dataset, outline.indexes), total=len(outline.dataset)
|
zip(outline.dataset, outline.indexes), total=len(outline.dataset)
|
||||||
):
|
):
|
||||||
|
@ -41,6 +42,7 @@ if __name__ == "__main__":
|
||||||
)
|
)
|
||||||
|
|
||||||
for attack in outline.attacks:
|
for attack in outline.attacks:
|
||||||
|
logging.info(f"Running {attack.__class__.__name__}: {attack.name}")
|
||||||
for item, index in tqdm(
|
for item, index in tqdm(
|
||||||
zip(outline.dataset, outline.indexes), total=len(outline.dataset)
|
zip(outline.dataset, outline.indexes), total=len(outline.dataset)
|
||||||
):
|
):
|
||||||
|
|
Loading…
Reference in New Issue