Reformating, fixing
This commit is contained in:
parent
7e32d6fd3a
commit
10baa1d443
|
@ -98,7 +98,7 @@ def set_cfg(explaining_cfg):
|
||||||
explaining_cfg.model_config = CN()
|
explaining_cfg.model_config = CN()
|
||||||
|
|
||||||
# Do not modify it, will be handled by dataset , assuming one dataset = one learning task
|
# Do not modify it, will be handled by dataset , assuming one dataset = one learning task
|
||||||
explaining_cfg.model_config.mode = None
|
explaining_cfg.model_config.mode = "regression"
|
||||||
|
|
||||||
# Do not modify it, will be handled by dataset , assuming one dataset = one learning task
|
# Do not modify it, will be handled by dataset , assuming one dataset = one learning task
|
||||||
explaining_cfg.model_config.task_level = None
|
explaining_cfg.model_config.task_level = None
|
||||||
|
|
|
@ -39,7 +39,7 @@ class FGSM(Metric):
|
||||||
self.zero_thresh = 10**-6
|
self.zero_thresh = 10**-6
|
||||||
|
|
||||||
def forward(self, input, target, epsilon: float) -> Explanation:
|
def forward(self, input, target, epsilon: float) -> Explanation:
|
||||||
input_ = copy.copy(input)
|
input_ = input.clone()
|
||||||
grad = compute_gradient(
|
grad = compute_gradient(
|
||||||
model=self.model, inp=input_, target=target, loss=self.loss
|
model=self.model, inp=input_, target=target, loss=self.loss
|
||||||
)
|
)
|
||||||
|
@ -168,24 +168,24 @@ class Attack(Metric):
|
||||||
def _gaussian_noise(self, exp) -> Explanation:
|
def _gaussian_noise(self, exp) -> Explanation:
|
||||||
x = torch.clone(exp.x)
|
x = torch.clone(exp.x)
|
||||||
x = x + torch.randn(*x.shape)
|
x = x + torch.randn(*x.shape)
|
||||||
exp_ = copy.copy(exp)
|
exp_ = exp.clone()
|
||||||
exp_.x = x
|
exp_.x = x
|
||||||
return exp_
|
return exp_
|
||||||
|
|
||||||
def _add_edge(self, exp, p: float) -> Explanation:
|
def _add_edge(self, exp, p: float) -> Explanation:
|
||||||
exp_ = copy.copy(exp)
|
exp_ = exp.clone()
|
||||||
exp_.edge_index, _ = add_random_edge(
|
exp_.edge_index, _ = add_random_edge(
|
||||||
exp_.edge_index, p=p, num_nodes=exp_.x.shape[0]
|
exp_.edge_index, p=p, num_nodes=exp_.x.shape[0]
|
||||||
)
|
)
|
||||||
return exp_
|
return exp_
|
||||||
|
|
||||||
def _remove_edge(self, exp, p: float) -> Explanation:
|
def _remove_edge(self, exp, p: float) -> Explanation:
|
||||||
exp_ = copy.copy(exp)
|
exp_ = exp.clone()
|
||||||
exp_.edge_index, _ = dropout_edge(exp_.edge_index, p=p)
|
exp_.edge_index, _ = dropout_edge(exp_.edge_index, p=p)
|
||||||
return exp_
|
return exp_
|
||||||
|
|
||||||
def _remove_node(self, exp, p: float) -> Explanation:
|
def _remove_node(self, exp, p: float) -> Explanation:
|
||||||
exp_ = copy.copy(exp)
|
exp_ = exp.clone()
|
||||||
exp_.edge_index, _, _ = dropout_node(
|
exp_.edge_index, _, _ = dropout_node(
|
||||||
exp_.edge_index, p=p, num_nodes=exp_.x.shape[0]
|
exp_.edge_index, p=p, num_nodes=exp_.x.shape[0]
|
||||||
)
|
)
|
||||||
|
|
|
@ -153,7 +153,7 @@ class GraphStat(object):
|
||||||
return maps
|
return maps
|
||||||
|
|
||||||
def __call__(self, data):
|
def __call__(self, data):
|
||||||
data_ = copy.copy(data)
|
data_ = data.clone()
|
||||||
datahash = hash(data.__repr__)
|
datahash = hash(data.__repr__)
|
||||||
stats = {}
|
stats = {}
|
||||||
for k, v in self.maps.items():
|
for k, v in self.maps.items():
|
||||||
|
@ -161,7 +161,7 @@ class GraphStat(object):
|
||||||
_data_ = to_networkx(data)
|
_data_ = to_networkx(data)
|
||||||
_data_ = _data_.to_undirected()
|
_data_ = _data_.to_undirected()
|
||||||
elif k == "torch_geometric":
|
elif k == "torch_geometric":
|
||||||
_data_ = copy.copy(data)
|
_data_ = data.clone()
|
||||||
for name, func in v.items():
|
for name, func in v.items():
|
||||||
try:
|
try:
|
||||||
val = func(_data_)
|
val = func(_data_)
|
||||||
|
|
|
@ -3,6 +3,18 @@ import itertools
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from eixgnn.eixgnn import EiXGNN
|
from eixgnn.eixgnn import EiXGNN
|
||||||
|
from scgnn.scgnn import SCGNN
|
||||||
|
from torch_geometric import seed_everything
|
||||||
|
from torch_geometric.data import Batch, Data
|
||||||
|
from torch_geometric.explain import Explainer
|
||||||
|
from torch_geometric.explain.config import ThresholdConfig
|
||||||
|
from torch_geometric.explain.explanation import Explanation
|
||||||
|
from torch_geometric.graphgym.config import cfg
|
||||||
|
from torch_geometric.graphgym.loader import create_dataset
|
||||||
|
from torch_geometric.graphgym.model_builder import cfg, create_model
|
||||||
|
from torch_geometric.graphgym.utils.device import auto_select_device
|
||||||
|
from torch_geometric.loader.dataloader import DataLoader
|
||||||
|
|
||||||
from explaining_framework.config.explainer_config.eixgnn_config import \
|
from explaining_framework.config.explainer_config.eixgnn_config import \
|
||||||
eixgnn_cfg
|
eixgnn_cfg
|
||||||
from explaining_framework.config.explainer_config.scgnn_config import scgnn_cfg
|
from explaining_framework.config.explainer_config.scgnn_config import scgnn_cfg
|
||||||
|
@ -24,17 +36,6 @@ from explaining_framework.utils.explanation.io import (
|
||||||
explanation_verification, get_pred)
|
explanation_verification, get_pred)
|
||||||
from explaining_framework.utils.io import (is_exists, obj_config_to_str,
|
from explaining_framework.utils.io import (is_exists, obj_config_to_str,
|
||||||
read_json, write_json, write_yaml)
|
read_json, write_json, write_yaml)
|
||||||
from scgnn.scgnn import SCGNN
|
|
||||||
from torch_geometric import seed_everything
|
|
||||||
from torch_geometric.data import Batch, Data
|
|
||||||
from torch_geometric.explain import Explainer
|
|
||||||
from torch_geometric.explain.config import ThresholdConfig
|
|
||||||
from torch_geometric.explain.explanation import Explanation
|
|
||||||
from torch_geometric.graphgym.config import cfg
|
|
||||||
from torch_geometric.graphgym.loader import create_dataset
|
|
||||||
from torch_geometric.graphgym.model_builder import cfg, create_model
|
|
||||||
from torch_geometric.graphgym.utils.device import auto_select_device
|
|
||||||
from torch_geometric.loader.dataloader import DataLoader
|
|
||||||
|
|
||||||
all__captum = [
|
all__captum = [
|
||||||
"LRP",
|
"LRP",
|
||||||
|
@ -133,6 +134,7 @@ class ExplainingOutline(object):
|
||||||
self.load_explaining_cfg()
|
self.load_explaining_cfg()
|
||||||
self.load_model_info()
|
self.load_model_info()
|
||||||
self.load_cfg()
|
self.load_cfg()
|
||||||
|
self.load_cfg_to_explaining_cfg()
|
||||||
self.load_dataset()
|
self.load_dataset()
|
||||||
self.load_model()
|
self.load_model()
|
||||||
self.load_model_to_hardware()
|
self.load_model_to_hardware()
|
||||||
|
@ -165,8 +167,10 @@ class ExplainingOutline(object):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def load_indexes(self):
|
def load_indexes(self):
|
||||||
if not self.explaining_cfg.dataset.specific_items is None:
|
|
||||||
indexes = explaining_cfg.dataset.specific_items
|
items = self.explaining_cfg.dataset.items
|
||||||
|
if isinstance(items, (list, int)):
|
||||||
|
indexes = items
|
||||||
else:
|
else:
|
||||||
indexes = list(range(len(self.dataset)))
|
indexes = list(range(len(self.dataset)))
|
||||||
self.indexes = iter(indexes)
|
self.indexes = iter(indexes)
|
||||||
|
@ -195,15 +199,20 @@ class ExplainingOutline(object):
|
||||||
self.model_signature = info.get_model_signature()
|
self.model_signature = info.get_model_signature()
|
||||||
|
|
||||||
def load_cfg(self):
|
def load_cfg(self):
|
||||||
cfg.set_new_allowed(True)
|
|
||||||
cfg.merge_from_file(self.model_info["cfg_path"])
|
cfg.merge_from_file(self.model_info["cfg_path"])
|
||||||
self.cfg = cfg
|
self.cfg = cfg
|
||||||
|
|
||||||
def load_explaining_cfg(self):
|
def load_explaining_cfg(self):
|
||||||
explaining_cfg.set_new_allowed(True)
|
|
||||||
explaining_cfg.merge_from_file(self.explaining_cfg_path)
|
explaining_cfg.merge_from_file(self.explaining_cfg_path)
|
||||||
self.explaining_cfg = explaining_cfg
|
self.explaining_cfg = explaining_cfg
|
||||||
|
|
||||||
|
def load_cfg_to_explaining_cfg(self):
|
||||||
|
if self.cfg is None:
|
||||||
|
self.load_cfg()
|
||||||
|
if self.explaining_cfg is None:
|
||||||
|
self.load_explaining_cfg()
|
||||||
|
self.explaining_cfg.model_config.task_level = self.cfg.dataset.task
|
||||||
|
|
||||||
def load_explainer_cfg(self):
|
def load_explainer_cfg(self):
|
||||||
if self.explaining_cfg is None:
|
if self.explaining_cfg is None:
|
||||||
self.load_explaining_cfg()
|
self.load_explaining_cfg()
|
||||||
|
@ -217,11 +226,9 @@ class ExplainingOutline(object):
|
||||||
self.explainer_cfg = None
|
self.explainer_cfg = None
|
||||||
else:
|
else:
|
||||||
if self.explaining_cfg.explainer.name == "EIXGNN":
|
if self.explaining_cfg.explainer.name == "EIXGNN":
|
||||||
eixgnn_cfg.set_new_allowed(True)
|
|
||||||
eixgnn_cfg.merge_from_file(self.explaining_cfg.explainer.cfg)
|
eixgnn_cfg.merge_from_file(self.explaining_cfg.explainer.cfg)
|
||||||
self.explainer_cfg = eixgnn_cfg
|
self.explainer_cfg = eixgnn_cfg
|
||||||
elif self.explaining_cfg.explainer.name == "SCGNN":
|
elif self.explaining_cfg.explainer.name == "SCGNN":
|
||||||
scgnn_cfg.set_new_allowed(True)
|
|
||||||
scgnn_cfg.merge_from_file(self.explaining_cfg.explainer.cfg)
|
scgnn_cfg.merge_from_file(self.explaining_cfg.explainer.cfg)
|
||||||
self.explainer_cfg = scgnn_cfg
|
self.explainer_cfg = scgnn_cfg
|
||||||
|
|
||||||
|
@ -245,12 +252,13 @@ class ExplainingOutline(object):
|
||||||
f"Expecting that the dataset to perform explanation on is the same as the model has trained on. Get {self.explaining_cfg.dataset.name} for explanation part, and {self.cfg.dataset.name} for the model."
|
f"Expecting that the dataset to perform explanation on is the same as the model has trained on. Get {self.explaining_cfg.dataset.name} for explanation part, and {self.cfg.dataset.name} for the model."
|
||||||
)
|
)
|
||||||
self.dataset = create_dataset()
|
self.dataset = create_dataset()
|
||||||
if isinstance(self.explaining_cfg.dataset.specific_items, int):
|
items = self.explaining_cfg.dataset.items
|
||||||
ind = self.explaining_cfg.dataset.specific_items
|
print(items)
|
||||||
self.dataset = self.dataset[ind : ind + 1]
|
print(type(items))
|
||||||
elif isinstance(self.explaining_cfg.dataset.specific_items, list):
|
if isinstance(items, int):
|
||||||
ind = self.explaining_cfg.dataset.specific_items
|
self.dataset = self.dataset[items : items + 1]
|
||||||
self.dataset = self.dataset[ind]
|
elif isinstance(items, list):
|
||||||
|
self.dataset = self.dataset[items]
|
||||||
|
|
||||||
def load_dataset_to_dataloader(self, to_iter=True):
|
def load_dataset_to_dataloader(self, to_iter=True):
|
||||||
self.dataset = DataLoader(dataset=self.dataset, shuffle=False, batch_size=1)
|
self.dataset = DataLoader(dataset=self.dataset, shuffle=False, batch_size=1)
|
||||||
|
@ -308,7 +316,7 @@ class ExplainingOutline(object):
|
||||||
),
|
),
|
||||||
model_config=dict(
|
model_config=dict(
|
||||||
mode="regression",
|
mode="regression",
|
||||||
task_level=self.cfg.dataset.task,
|
task_level=self.explaining_cfg.model_config.task_level,
|
||||||
return_type=self.explaining_cfg.model_config.return_type,
|
return_type=self.explaining_cfg.model_config.return_type,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
@ -535,8 +543,6 @@ class ExplainingOutline(object):
|
||||||
explanation = _get_explanation(self.explainer, item)
|
explanation = _get_explanation(self.explainer, item)
|
||||||
get_pred(self.explainer, explanation)
|
get_pred(self.explainer, explanation)
|
||||||
_save_explanation(explanation, path)
|
_save_explanation(explanation, path)
|
||||||
explanation = explanation.to(cfg.accelerator)
|
|
||||||
|
|
||||||
return explanation
|
return explanation
|
||||||
|
|
||||||
def get_adjust(self, adjust: Adjust, item: Explanation, path: str):
|
def get_adjust(self, adjust: Adjust, item: Explanation, path: str):
|
||||||
|
@ -549,7 +555,6 @@ class ExplainingOutline(object):
|
||||||
exp_adjust = adjust.forward(item)
|
exp_adjust = adjust.forward(item)
|
||||||
get_pred(self.explainer, exp_adjust)
|
get_pred(self.explainer, exp_adjust)
|
||||||
_save_explanation(exp_adjust, path)
|
_save_explanation(exp_adjust, path)
|
||||||
exp_adjust = exp_adjust.to(cfg.accelerator)
|
|
||||||
return exp_adjust
|
return exp_adjust
|
||||||
|
|
||||||
def get_threshold(self, item: Explanation, path: str):
|
def get_threshold(self, item: Explanation, path: str):
|
||||||
|
@ -562,7 +567,6 @@ class ExplainingOutline(object):
|
||||||
exp_threshold = self.explainer._post_process(item)
|
exp_threshold = self.explainer._post_process(item)
|
||||||
get_pred(self.explainer, exp_threshold)
|
get_pred(self.explainer, exp_threshold)
|
||||||
_save_explanation(exp_threshold, path)
|
_save_explanation(exp_threshold, path)
|
||||||
exp_threshold = exp_threshold.to(cfg.accelerator)
|
|
||||||
return exp_threshold
|
return exp_threshold
|
||||||
|
|
||||||
def get_metric(self, metric: Metric, item: Explanation, path: str):
|
def get_metric(self, metric: Metric, item: Explanation, path: str):
|
||||||
|
|
|
@ -14,7 +14,7 @@ class Adjust(object):
|
||||||
self.strategy = strategy
|
self.strategy = strategy
|
||||||
|
|
||||||
def forward(self, exp: Explanation) -> Explanation:
|
def forward(self, exp: Explanation) -> Explanation:
|
||||||
exp_ = copy.copy(exp)
|
exp_ = exp.clone()
|
||||||
_store = exp_.to_dict()
|
_store = exp_.to_dict()
|
||||||
for k, v in _store.items():
|
for k, v in _store.items():
|
||||||
if "mask" in k:
|
if "mask" in k:
|
||||||
|
@ -41,7 +41,7 @@ class Adjust(object):
|
||||||
return mask_
|
return mask_
|
||||||
|
|
||||||
def normalize(self, mask: FloatTensor) -> FloatTensor:
|
def normalize(self, mask: FloatTensor) -> FloatTensor:
|
||||||
norm = torch.norm(mask, p="inf")
|
norm = torch.norm(mask, p=float("inf"))
|
||||||
if norm.item() > 0:
|
if norm.item() > 0:
|
||||||
mask_ = mask / norm.item()
|
mask_ = mask / norm.item()
|
||||||
return mask_
|
return mask_
|
||||||
|
|
|
@ -5,6 +5,7 @@ import os
|
||||||
import torch
|
import torch
|
||||||
from torch_geometric.data import Data
|
from torch_geometric.data import Data
|
||||||
from torch_geometric.explain.explanation import Explanation
|
from torch_geometric.explain.explanation import Explanation
|
||||||
|
from torch_geometric.graphgym.config import cfg
|
||||||
|
|
||||||
|
|
||||||
def _get_explanation(explainer, item):
|
def _get_explanation(explainer, item):
|
||||||
|
@ -18,6 +19,7 @@ def _get_explanation(explainer, item):
|
||||||
# WARNING + LOG
|
# WARNING + LOG
|
||||||
return None
|
return None
|
||||||
else:
|
else:
|
||||||
|
explanation = explanation.to(cfg.accelerator)
|
||||||
return explanation
|
return explanation
|
||||||
|
|
||||||
|
|
||||||
|
@ -47,9 +49,8 @@ def explanation_verification(exp: Explanation) -> bool:
|
||||||
for mask in masks:
|
for mask in masks:
|
||||||
is_nan = mask.isnan().any().item()
|
is_nan = mask.isnan().any().item()
|
||||||
is_inf = mask.isinf().any().item()
|
is_inf = mask.isinf().any().item()
|
||||||
is_const = mask.max() == mask.min()
|
|
||||||
is_ok = exp.validate()
|
is_ok = exp.validate()
|
||||||
if is_nan or is_inf or not is_ok or is_const:
|
if is_nan or is_inf or not is_ok:
|
||||||
is_good = False
|
is_good = False
|
||||||
return is_good
|
return is_good
|
||||||
else:
|
else:
|
||||||
|
@ -58,7 +59,7 @@ def explanation_verification(exp: Explanation) -> bool:
|
||||||
|
|
||||||
|
|
||||||
def _save_explanation(exp: Explanation, path: str) -> None:
|
def _save_explanation(exp: Explanation, path: str) -> None:
|
||||||
data = copy.copy(exp).to_dict()
|
data = exp.clone().to_dict()
|
||||||
for k, v in data.items():
|
for k, v in data.items():
|
||||||
if isinstance(v, torch.Tensor):
|
if isinstance(v, torch.Tensor):
|
||||||
data[k] = v.detach().cpu().tolist()
|
data[k] = v.detach().cpu().tolist()
|
||||||
|
@ -77,5 +78,3 @@ def _load_explanation(path: str) -> Explanation:
|
||||||
else:
|
else:
|
||||||
data[k] = torch.FloatTensor(v)
|
data[k] = torch.FloatTensor(v)
|
||||||
return Explanation.from_dict(data)
|
return Explanation.from_dict(data)
|
||||||
|
|
||||||
|
|
||||||
|
|
5
main.py
5
main.py
|
@ -27,6 +27,7 @@ from explaining_framework.utils.io import (is_exists, obj_config_to_str,
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
args = parse_args()
|
args = parse_args()
|
||||||
outline = ExplainingOutline(args.explaining_cfg_file)
|
outline = ExplainingOutline(args.explaining_cfg_file)
|
||||||
|
print(outline.explaining_cfg)
|
||||||
|
|
||||||
out_dir = os.path.join(outline.explaining_cfg.out_dir, outline.model_signature)
|
out_dir = os.path.join(outline.explaining_cfg.out_dir, outline.model_signature)
|
||||||
makedirs(out_dir)
|
makedirs(out_dir)
|
||||||
|
@ -64,14 +65,14 @@ if __name__ == "__main__":
|
||||||
for adjust in outline.adjusts:
|
for adjust in outline.adjusts:
|
||||||
adjust_path = os.path.join(raw_path, f"adjust-{obj_config_to_str(adjust)}")
|
adjust_path = os.path.join(raw_path, f"adjust-{obj_config_to_str(adjust)}")
|
||||||
makedirs(adjust_path)
|
makedirs(adjust_path)
|
||||||
exp_adjust_path = os.path.join(exp_adjust_path, f"{index}.json")
|
exp_adjust_path = os.path.join(adjust_path, f"{index}.json")
|
||||||
exp_adjust = outline.get_adjust(
|
exp_adjust = outline.get_adjust(
|
||||||
adjust=adjust, item=raw_exp, path=exp_adjust_path
|
adjust=adjust, item=raw_exp, path=exp_adjust_path
|
||||||
)
|
)
|
||||||
for threshold_conf in outline.thresholds_configs:
|
for threshold_conf in outline.thresholds_configs:
|
||||||
outline.set_explainer_threshold_config(threshold_conf)
|
outline.set_explainer_threshold_config(threshold_conf)
|
||||||
masking_path = os.path.join(
|
masking_path = os.path.join(
|
||||||
save_raw_path_,
|
adjust_path,
|
||||||
"-".join([f"{k}={v}" for k, v in threshold_conf.items()]),
|
"-".join([f"{k}={v}" for k, v in threshold_conf.items()]),
|
||||||
)
|
)
|
||||||
makedirs(masking_path)
|
makedirs(masking_path)
|
||||||
|
|
Loading…
Reference in New Issue