Reformating
This commit is contained in:
parent
68449ad678
commit
fb012ad723
|
@ -1,16 +1,8 @@
|
||||||
import copy
|
import copy
|
||||||
|
import itertools
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from eixgnn.eixgnn import EiXGNN
|
from eixgnn.eixgnn import EiXGNN
|
||||||
from scgnn.scgnn import SCGNN
|
|
||||||
from torch_geometric.data import Batch, Data
|
|
||||||
from torch_geometric.explain import Explainer
|
|
||||||
from torch_geometric.graphgym.config import cfg
|
|
||||||
from torch_geometric.graphgym.loader import create_dataset
|
|
||||||
from torch_geometric.graphgym.model_builder import cfg, create_model
|
|
||||||
from torch_geometric.graphgym.utils.device import auto_select_device
|
|
||||||
from torch_geometric.loader.dataloader import DataLoader
|
|
||||||
|
|
||||||
from explaining_framework.config.explainer_config.eixgnn_config import \
|
from explaining_framework.config.explainer_config.eixgnn_config import \
|
||||||
eixgnn_cfg
|
eixgnn_cfg
|
||||||
from explaining_framework.config.explainer_config.scgnn_config import scgnn_cfg
|
from explaining_framework.config.explainer_config.scgnn_config import scgnn_cfg
|
||||||
|
@ -22,8 +14,19 @@ from explaining_framework.metric.accuracy import Accuracy
|
||||||
from explaining_framework.metric.fidelity import Fidelity
|
from explaining_framework.metric.fidelity import Fidelity
|
||||||
from explaining_framework.metric.robust import Attack
|
from explaining_framework.metric.robust import Attack
|
||||||
from explaining_framework.metric.sparsity import Sparsity
|
from explaining_framework.metric.sparsity import Sparsity
|
||||||
|
from explaining_framework.stats.graph.graph_stat import GraphStat
|
||||||
from explaining_framework.utils.explaining.load_ckpt import (LoadModelInfo,
|
from explaining_framework.utils.explaining.load_ckpt import (LoadModelInfo,
|
||||||
_load_ckpt)
|
_load_ckpt)
|
||||||
|
from explaining_framework.utils.explanation.adjust import Adjust
|
||||||
|
from scgnn.scgnn import SCGNN
|
||||||
|
from torch_geometric.data import Batch, Data
|
||||||
|
from torch_geometric.explain import Explainer
|
||||||
|
from torch_geometric.explain.config import ThresholdConfig
|
||||||
|
from torch_geometric.graphgym.config import cfg
|
||||||
|
from torch_geometric.graphgym.loader import create_dataset
|
||||||
|
from torch_geometric.graphgym.model_builder import cfg, create_model
|
||||||
|
from torch_geometric.graphgym.utils.device import auto_select_device
|
||||||
|
from torch_geometric.loader.dataloader import DataLoader
|
||||||
|
|
||||||
all__captum = [
|
all__captum = [
|
||||||
"LRP",
|
"LRP",
|
||||||
|
@ -85,6 +88,10 @@ all_robust = [
|
||||||
]
|
]
|
||||||
all_sparsity = ["l0"]
|
all_sparsity = ["l0"]
|
||||||
|
|
||||||
|
adjust_pattern = 'ranp'
|
||||||
|
all_adjusts_filters = [''.join(filters) for i in range(len(adjust_pattern)+1)for filters in itertools.permutations(adjust_pattern,i)]
|
||||||
|
|
||||||
|
all_threshold_type = ['topk_hard','hard','topk']
|
||||||
|
|
||||||
class ExplainingOutline(object):
|
class ExplainingOutline(object):
|
||||||
def __init__(self, explaining_cfg_path: str):
|
def __init__(self, explaining_cfg_path: str):
|
||||||
|
@ -100,17 +107,65 @@ class ExplainingOutline(object):
|
||||||
self.metrics = None
|
self.metrics = None
|
||||||
self.attacks = None
|
self.attacks = None
|
||||||
self.model_signature = None
|
self.model_signature = None
|
||||||
|
self.indexes = None
|
||||||
|
self.explaining_algorithm = None
|
||||||
|
self.explainer = None
|
||||||
|
self.adjusts = None
|
||||||
|
self.thresholds_configs = None
|
||||||
|
self.graphstat = None
|
||||||
|
|
||||||
self.load_explaining_cfg()
|
self.load_explaining_cfg()
|
||||||
self.load_model_info()
|
self.load_model_info()
|
||||||
self.load_cfg()
|
self.load_cfg()
|
||||||
self.load_dataset()
|
self.load_dataset()
|
||||||
self.load_model()
|
self.load_model()
|
||||||
|
self.load_model_to_hardware()
|
||||||
self.load_explainer_cfg()
|
self.load_explainer_cfg()
|
||||||
|
self.load_explaining_algorithm()
|
||||||
self.load_explainer()
|
self.load_explainer()
|
||||||
self.load_metric()
|
self.load_metric()
|
||||||
self.load_attack()
|
self.load_attack()
|
||||||
self.load_dataset_to_dataloader()
|
self.load_dataset_to_dataloader()
|
||||||
|
self.load_indexes()
|
||||||
|
self.load_adjust()
|
||||||
|
self.load_threshold()
|
||||||
|
self.load_graphstat()
|
||||||
|
|
||||||
|
def load_model_to_hardware(self):
|
||||||
|
auto_select_device()
|
||||||
|
device = self.cfg.accelerator
|
||||||
|
self.model = self.model.to(device)
|
||||||
|
|
||||||
|
def get_data(self):
|
||||||
|
if self.dataset is None:
|
||||||
|
self.load_dataset()
|
||||||
|
try:
|
||||||
|
item = next(self.dataset)
|
||||||
|
item = item.to(cfg.accelerator)
|
||||||
|
return item
|
||||||
|
except StopIteration:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def load_indexes(self):
|
||||||
|
if not self.explaining_cfg.dataset.specific_items is None:
|
||||||
|
indexes = explaining_cfg.dataset.specific_items
|
||||||
|
else:
|
||||||
|
indexes = list(range(len(self.dataset)))
|
||||||
|
self.indexes = iter(indexes)
|
||||||
|
|
||||||
|
def get_index(self):
|
||||||
|
if self.indexes is None:
|
||||||
|
self.load_indexes()
|
||||||
|
try:
|
||||||
|
item = next(self.indexes)
|
||||||
|
return item
|
||||||
|
except StopIteration:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def get_item(self):
|
||||||
|
item = self.get_data()
|
||||||
|
index = self.get_index()
|
||||||
|
return item, index
|
||||||
|
|
||||||
def load_model_info(self):
|
def load_model_info(self):
|
||||||
info = LoadModelInfo(
|
info = LoadModelInfo(
|
||||||
|
@ -160,6 +215,7 @@ class ExplainingOutline(object):
|
||||||
self.model = _load_ckpt(self.model, self.model_info["ckpt_path"])
|
self.model = _load_ckpt(self.model, self.model_info["ckpt_path"])
|
||||||
if self.model is None:
|
if self.model is None:
|
||||||
raise ValueError("Model ckpt has not been loaded, ckpt file not found")
|
raise ValueError("Model ckpt has not been loaded, ckpt file not found")
|
||||||
|
self.model = self.model.eval()
|
||||||
|
|
||||||
def load_dataset(self):
|
def load_dataset(self):
|
||||||
if self.cfg is None:
|
if self.cfg is None:
|
||||||
|
@ -181,7 +237,7 @@ class ExplainingOutline(object):
|
||||||
def load_dataset_to_dataloader(self):
|
def load_dataset_to_dataloader(self):
|
||||||
self.dataset = DataLoader(dataset=self.dataset, shuffle=False, batch_size=1)
|
self.dataset = DataLoader(dataset=self.dataset, shuffle=False, batch_size=1)
|
||||||
|
|
||||||
def load_explainer(self):
|
def load_explaining_algorithm(self):
|
||||||
self.load_explainer_cfg()
|
self.load_explainer_cfg()
|
||||||
if self.model is None:
|
if self.model is None:
|
||||||
self.load_model()
|
self.load_model()
|
||||||
|
@ -219,54 +275,216 @@ class ExplainingOutline(object):
|
||||||
raise ValueError(f"{name_} Metric is not supported yet")
|
raise ValueError(f"{name_} Metric is not supported yet")
|
||||||
self.explaining_algorithm = explaining_algorithm
|
self.explaining_algorithm = explaining_algorithm
|
||||||
|
|
||||||
|
def load_explainer(self):
|
||||||
|
if self.explaining_algorithm is None:
|
||||||
|
self.load_explaining_algorithm()
|
||||||
|
explainer = Explainer(
|
||||||
|
model=self.model,
|
||||||
|
algorithm=self.explaining_algorithm,
|
||||||
|
explainer_config=dict(
|
||||||
|
explanation_type=self.explaining_cfg.explanation_type,
|
||||||
|
node_mask_type="object",
|
||||||
|
edge_mask_type="object",
|
||||||
|
),
|
||||||
|
model_config=dict(
|
||||||
|
mode="regression",
|
||||||
|
task_level=self.cfg.dataset.task,
|
||||||
|
return_type=self.explaining_cfg.model_config.return_type,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
self.explainer = explainer
|
||||||
|
|
||||||
|
def load_fidelity(self):
|
||||||
|
if self.cfg is None:
|
||||||
|
self.load_cfg()
|
||||||
|
if self.explaining_cfg is None:
|
||||||
|
self.load_explaining_cfg()
|
||||||
|
name = self.explaining_cfg.metrics.fidelity.name
|
||||||
|
if name == 'all':
|
||||||
|
all_metrics = [
|
||||||
|
Fidelity(name=name, model=self.model) for name in all_fidelity
|
||||||
|
]
|
||||||
|
elif isinstance(name,str):
|
||||||
|
if name in all_fidelity:
|
||||||
|
all_metrics = [Fidelity(name=name, model=self.model)]
|
||||||
|
else:
|
||||||
|
raise ValueError(f'This fidelity metric {name} is nor supported yet. Supported are {all_fidelity}')
|
||||||
|
elif isinstance(name,list):
|
||||||
|
all_metrics = [Fidelity(name=name, model=self.model) for name_ in name if name_ in all_fidelity]
|
||||||
|
elif name is None:
|
||||||
|
all_metrics = []
|
||||||
|
self.fidelities = all_metrics
|
||||||
|
|
||||||
|
def load_sparsity(self):
|
||||||
|
if self.cfg is None:
|
||||||
|
self.load_cfg()
|
||||||
|
if self.explaining_cfg is None:
|
||||||
|
self.load_explaining_cfg()
|
||||||
|
name = self.explaining_cfg.metrics.sparsity.name
|
||||||
|
if name == 'all':
|
||||||
|
all_metrics = [
|
||||||
|
Sparsity(name=name) for name in all_sparsity
|
||||||
|
]
|
||||||
|
elif isinstance(name,str):
|
||||||
|
if name in all_sparsity:
|
||||||
|
all_metrics = [Sparsity(name=name)]
|
||||||
|
else:
|
||||||
|
raise ValueError(f'This sparsity metric {name} is nor supported yet. Supported are {all_sparsity}')
|
||||||
|
elif isinstance(name,list):
|
||||||
|
all_metrics = [Sparsity(name=name) for name_ in name if name_ in all_sparsity]
|
||||||
|
elif name is None:
|
||||||
|
all_metrics = []
|
||||||
|
self.sparsities = all_metrics
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def load_accuracy(self):
|
||||||
|
if self.cfg is None:
|
||||||
|
self.load_cfg()
|
||||||
|
if self.explaining_cfg is None:
|
||||||
|
self.load_explaining_cfg()
|
||||||
|
|
||||||
|
if self.explaining_cfg.dataset.name == "BASHAPES":
|
||||||
|
name = self.explaining_cfg.metrics.accuracy.name
|
||||||
|
if name == 'all':
|
||||||
|
all_metrics = [
|
||||||
|
Accuracy(name=name) for name in all_accuracy
|
||||||
|
]
|
||||||
|
elif isinstance(name,str):
|
||||||
|
if name in all_accuracy:
|
||||||
|
all_metrics = [Accuracy(name=name)]
|
||||||
|
else:
|
||||||
|
raise ValueError(f'This accuracy metric {name} is nor supported yet. Supported are {all_accuracy}')
|
||||||
|
elif isinstance(name,list):
|
||||||
|
all_metrics = [Accuracy(name=name) for name_ in name if name_ in all_accuracy]
|
||||||
|
elif name is None:
|
||||||
|
all_metrics = []
|
||||||
|
self.accuraties = all_metrics
|
||||||
|
else:
|
||||||
|
raise ValueError(f'Provided dataset needs explanation groundtruths for using Accuracies metric, e.g BASHAPES dataset')
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def load_metric(self):
|
def load_metric(self):
|
||||||
if self.cfg is None:
|
if self.cfg is None:
|
||||||
self.load_cfg()
|
self.load_cfg()
|
||||||
if self.explaining_cfg is None:
|
if self.explaining_cfg is None:
|
||||||
self.load_explaining_cfg()
|
self.load_explaining_cfg()
|
||||||
|
if self.accuraties is None:
|
||||||
|
self.load_accuracy()
|
||||||
|
if self.sparsities is None:
|
||||||
|
self.load_sparsity()
|
||||||
|
if self.fidelities is None:
|
||||||
|
self.load_fidelity()
|
||||||
|
|
||||||
|
self.metrics = self.fidelities+self.accuraties+self.sparsities
|
||||||
|
|
||||||
name_ = self.explaining_cfg.metrics.name
|
|
||||||
|
|
||||||
if name_ == "all":
|
|
||||||
all_fid_metrics = [
|
|
||||||
Fidelity(name=name, model=self.model) for name in all_fidelity
|
|
||||||
]
|
|
||||||
all_spa_metrics = [Sparsity(name) for name in all_sparsity]
|
|
||||||
self.metrics = all_spa_metrics + all_fid_metrics
|
|
||||||
|
|
||||||
if self.explaining_cfg.dataset.name == "BASHAPES":
|
|
||||||
all_acc_metrics = [Accuracy(name) for name in all_accuracy]
|
|
||||||
self.metrics = self.metrics + all_acc_metrics
|
|
||||||
elif name_ in all_fidelity:
|
|
||||||
self.metrics = [Fidelity(name=name_, model=self.model)]
|
|
||||||
elif name_ in all_sparsity:
|
|
||||||
self.metrics = [Sparsity(name_)]
|
|
||||||
elif name_ in all_accuracy:
|
|
||||||
if self.explaining_cfg.dataset.name == "BASHAPES":
|
|
||||||
self.metrics = [Accuracy(name_)]
|
|
||||||
else:
|
|
||||||
raise ValueError(
|
|
||||||
f"The metric {name} is not supported for dataset {self.explaining_cfg.dataset.name} yet, it requires groundtruth explanation"
|
|
||||||
)
|
|
||||||
elif name_ is None:
|
|
||||||
self.metrics = []
|
|
||||||
else:
|
|
||||||
raise ValueError(f"{name_} Metric is not supported yet")
|
|
||||||
|
|
||||||
def load_attack(self):
|
def load_attack(self):
|
||||||
if self.cfg is None:
|
if self.cfg is None:
|
||||||
self.load_cfg()
|
self.load_cfg()
|
||||||
if self.explaining_cfg is None:
|
if self.explaining_cfg is None:
|
||||||
self.load_explaining_cfg()
|
self.load_explaining_cfg()
|
||||||
name_ = self.explaining_cfg.attack.name
|
name = self.explaining_cfg.attack.name
|
||||||
if name_ == "all":
|
if name == 'all':
|
||||||
all_rob_metrics = [
|
all_metrics = [
|
||||||
Attack(name=name, model=self.model) for name in all_robust
|
Attack(name=name,model=self.model) for name in all_robust
|
||||||
]
|
]
|
||||||
self.attacks = all_rob_metrics
|
elif isinstance(name,str):
|
||||||
elif name_ in all_robust:
|
if name in all_robust:
|
||||||
self.attacks = [Attack(name=name_, model=self.model)]
|
all_metrics = [Attack(name=name,model=self.model)]
|
||||||
elif name_ is None:
|
else:
|
||||||
self.attacks = []
|
raise ValueError(f'This Attack metric {name} is not supported yet. Supported are {all_robust}')
|
||||||
|
elif isinstance(name,list):
|
||||||
|
all_metrics = [Attack(name=name,model=self.model) for name_ in name if name_ in all_robust]
|
||||||
|
elif name is None:
|
||||||
|
all_metrics = []
|
||||||
|
self.attacks = all_metrics
|
||||||
|
|
||||||
|
def load_adjust(self):
|
||||||
|
if self.explaining_cfg is None:
|
||||||
|
self.load_explaining_cfg()
|
||||||
|
strategy = self.explaining_cfg.adjust.strategy
|
||||||
|
if strategy == "all":
|
||||||
|
self.adjusts = [Adjust(strategy=strat) for strat in all_adjusts_filters]
|
||||||
|
elif isinstance(name,str):
|
||||||
|
if name in all_adjusts_filters:
|
||||||
|
all_metrics = [Adjust(strategy=name)]
|
||||||
|
else:
|
||||||
|
raise ValueError(f'This Adjust metric {name} is not supported yet. Supported are {all_adjusts_filters}')
|
||||||
|
elif isinstance(name,list):
|
||||||
|
all_metrics = [Adjust(strategy=name_) for name_ in name if name_ in all_robust]
|
||||||
|
elif name is None:
|
||||||
|
all_metrics = []
|
||||||
|
self.adjusts = all_metrics
|
||||||
|
|
||||||
|
def load_threshold(self):
|
||||||
|
if self.explaining_cfg is None:
|
||||||
|
self.load_explaining_cfg()
|
||||||
|
threshold_type =self.explaining_cfg.threshold_config.type
|
||||||
|
if threshold_type == 'all':
|
||||||
|
th_hard = [{"threshold_type": 'hard',"value": th_value} for th_value in self.explaining_cfg.threshold.value.hard]
|
||||||
|
th_topk = [{"threshold_type": th_type,"value": th_value} for th_value in self.explaining_cfg.threshold.value.topk f or th_type in all_threshold_type if 'topk' in th_type]
|
||||||
|
all_threshold = th_hard + th_topk
|
||||||
|
elif isinstance(threshold_type,str):
|
||||||
|
if threshold_type in all_threshold_type:
|
||||||
|
if 'topk' in threshold_type:
|
||||||
|
all_threshold = [{
|
||||||
|
"threshold_type": threshold_type,
|
||||||
|
"value": threshold_value,
|
||||||
|
} for threshold_value in self.explaining_cfg.threshold.value.topk]
|
||||||
|
elif threshold_type == 'hard':
|
||||||
|
all_threshold = [{
|
||||||
|
"threshold_type": threshold_type,
|
||||||
|
"value": threshold_value,
|
||||||
|
} for threshold_value in self.explaining_cfg.threshold.value.hard]
|
||||||
|
elif isinstance(threshold_type,list):
|
||||||
|
all_threshold = []
|
||||||
|
for tf_type in threshold_type:
|
||||||
|
if 'topk' in th_type:
|
||||||
|
all_threshold.expend([{
|
||||||
|
"threshold_type": threshold_type,
|
||||||
|
"value": threshold_value,
|
||||||
|
} for threshold_value in self.explaining_cfg.threshold.value.topk])
|
||||||
|
elif th_type == 'hard':
|
||||||
|
all_threshold.expend([{
|
||||||
|
"threshold_type": threshold_type,
|
||||||
|
"value": threshold_value,
|
||||||
|
} for threshold_value in self.explaining_cfg.threshold.value.hard])
|
||||||
|
|
||||||
|
elif threshold_type is None:
|
||||||
|
all_threshold = []
|
||||||
|
self.thresholds_configs = all_threshold
|
||||||
|
|
||||||
|
def set_explainer_threshold_config(self,threshold_config):
|
||||||
|
self.explainer.threshold_config = ThresholdConfig.cast(threshold_config)
|
||||||
|
|
||||||
|
def load_graphstat(self):
|
||||||
|
self.graphstat = GraphStat()
|
||||||
|
|
||||||
|
def get_explanation_(self,item:Data,path:str):
|
||||||
|
if is_exists(path):
|
||||||
|
if self.explaining_cfg.explainer.force:
|
||||||
|
explanation = get_explanation(self.explainer, item)
|
||||||
|
else:
|
||||||
|
explanation = load_explanation(path)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"{name_} is an Attack method that is not supported yet")
|
explanation = get_explanation(explainer, item)
|
||||||
|
save_explanation(explanation,path)
|
||||||
|
return explanation
|
||||||
|
|
||||||
|
|
||||||
|
class Explaining(object):
|
||||||
|
def __init__(self,outline:ExplainingOutline):
|
||||||
|
self.outline = outline
|
||||||
|
|
||||||
|
def run(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def explain(self):
|
||||||
|
item, index = self.get_item()
|
||||||
|
not_none = item is None or index is None
|
||||||
|
whœ
|
||||||
|
|
||||||
|
while
|
||||||
|
|
||||||
|
|
|
@ -9,37 +9,29 @@ from torch_geometric.explain.explanation import Explanation
|
||||||
class Adjust(object):
|
class Adjust(object):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
apply_relu: bool = True,
|
strategy: str = "rpn",
|
||||||
apply_normalize: bool = True,
|
|
||||||
apply_project: bool = True,
|
|
||||||
apply_absolute: bool = False,
|
|
||||||
):
|
):
|
||||||
self.apply_relu = apply_relu
|
self.strategy = strategy
|
||||||
self.apply_normalize = apply_normalize
|
|
||||||
self.apply_project = apply_project
|
|
||||||
self.apply_absolute = apply_absolute
|
|
||||||
|
|
||||||
if self.apply_absolute and self.apply_relu:
|
|
||||||
self.apply_relu = False
|
|
||||||
|
|
||||||
def forward(self, exp: Explanation) -> Explanation:
|
def forward(self, exp: Explanation) -> Explanation:
|
||||||
exp_ = copy.copy(exp)
|
exp_ = copy.copy(exp)
|
||||||
_store = exp_.to_dict()
|
_store = exp_.to_dict()
|
||||||
for k, v in _store.items():
|
for k, v in _store.items():
|
||||||
if "mask" in k:
|
if "mask" in k:
|
||||||
if self.apply_relu:
|
for f_ in self.strategy:
|
||||||
_store[k] = self.relu(v)
|
if f_ == "r":
|
||||||
elif self.apply_absolute:
|
_store[k] = self.relu(v)
|
||||||
_store[k] = self.absolute(v)
|
if f_ == "a":
|
||||||
elif self.apply_project:
|
_store[k] = self.absolute(v)
|
||||||
if "edge" in k:
|
if f_ == "p":
|
||||||
pass
|
if "edge" in k:
|
||||||
else:
|
pass
|
||||||
_store[k] = self.project(v)
|
else:
|
||||||
elif self.apply_normalize:
|
_store[k] = self.project(v)
|
||||||
_store[k] = self.normalize(v)
|
if f_ == "n":
|
||||||
else:
|
_store[k] = self.normalize(v)
|
||||||
continue
|
else:
|
||||||
|
continue
|
||||||
|
|
||||||
return exp_
|
return exp_
|
||||||
|
|
||||||
|
|
|
@ -7,6 +7,38 @@ from torch_geometric.data import Data
|
||||||
from torch_geometric.explain.explanation import Explanation
|
from torch_geometric.explain.explanation import Explanation
|
||||||
|
|
||||||
|
|
||||||
|
def get_explanation(explainer, item):
|
||||||
|
explanation = explainer(
|
||||||
|
x=item.x,
|
||||||
|
edge_index=item.edge_index,
|
||||||
|
index=int(item.y),
|
||||||
|
target=item.y,
|
||||||
|
)
|
||||||
|
# TODO return None if pas bien plutot
|
||||||
|
assert explanation_verification(explanation)
|
||||||
|
return explanation
|
||||||
|
|
||||||
|
|
||||||
|
def is_empty_graph(data: Data) -> bool:
|
||||||
|
return data.x.shape[0] == 0
|
||||||
|
|
||||||
|
|
||||||
|
def get_pred(explainer, explanation):
|
||||||
|
pred = explainer.get_prediction(x=explanation.x, edge_index=explanation.edge_index)[
|
||||||
|
0
|
||||||
|
]
|
||||||
|
setattr(explanation, "pred", pred)
|
||||||
|
data = explanation.to_dict()
|
||||||
|
if not data.get("node_mask") is None or not data.get("edge_mask") is None:
|
||||||
|
pred_masked = explainer.get_masked_prediction(
|
||||||
|
x=explanation.x,
|
||||||
|
edge_index=explanation.edge_index,
|
||||||
|
node_mask=data.get("node_mask"),
|
||||||
|
edge_mask=data.get("edge_mask"),
|
||||||
|
)[0]
|
||||||
|
setattr(explanation, "pred_exp", pred_masked)
|
||||||
|
|
||||||
|
|
||||||
def explanation_verification(exp: Explanation) -> bool:
|
def explanation_verification(exp: Explanation) -> bool:
|
||||||
is_good = True
|
is_good = True
|
||||||
masks = [v for k, v in exp.items() if "_mask" in k and isinstance(v, torch.Tensor)]
|
masks = [v for k, v in exp.items() if "_mask" in k and isinstance(v, torch.Tensor)]
|
||||||
|
@ -53,5 +85,4 @@ def normalize_explanation_masks(exp: Explanation, p: str = "inf") -> Explanation
|
||||||
norm = torch.norm(input=data[k], p=p, dim=None).item()
|
norm = torch.norm(input=data[k], p=p, dim=None).item()
|
||||||
if norm.item() > 0:
|
if norm.item() > 0:
|
||||||
data[k] = data[k] / norm
|
data[k] = data[k] / norm
|
||||||
|
|
||||||
return exp
|
return exp
|
||||||
|
|
19
main.py
19
main.py
|
@ -19,7 +19,8 @@ from explaining_framework.utils.explaining.cmd_args import parse_args
|
||||||
from explaining_framework.utils.explaining.outline import ExplainingOutline
|
from explaining_framework.utils.explaining.outline import ExplainingOutline
|
||||||
from explaining_framework.utils.explanation.adjust import Adjust
|
from explaining_framework.utils.explanation.adjust import Adjust
|
||||||
from explaining_framework.utils.explanation.io import (
|
from explaining_framework.utils.explanation.io import (
|
||||||
explanation_verification, load_explanation, save_explanation)
|
explanation_verification, get_explanation, get_pred, load_explanation,
|
||||||
|
save_explanation)
|
||||||
from explaining_framework.utils.io import (is_exists, obj_config_to_str,
|
from explaining_framework.utils.io import (is_exists, obj_config_to_str,
|
||||||
read_json, write_json, write_yaml)
|
read_json, write_json, write_yaml)
|
||||||
|
|
||||||
|
@ -42,17 +43,6 @@ def get_pred(explainer, explanation):
|
||||||
setattr(explanation, "pred_exp", pred_masked)
|
setattr(explanation, "pred_exp", pred_masked)
|
||||||
|
|
||||||
|
|
||||||
def get_explanation(explainer, item):
|
|
||||||
explanation = explainer(
|
|
||||||
x=item.x,
|
|
||||||
edge_index=item.edge_index,
|
|
||||||
index=int(item.y),
|
|
||||||
target=item.y,
|
|
||||||
)
|
|
||||||
assert explanation_verification(explanation)
|
|
||||||
return explanation
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
args = parse_args()
|
args = parse_args()
|
||||||
outline = ExplainingOutline(args.explaining_cfg_file)
|
outline = ExplainingOutline(args.explaining_cfg_file)
|
||||||
|
@ -68,6 +58,7 @@ if __name__ == "__main__":
|
||||||
attacks = outline.attacks
|
attacks = outline.attacks
|
||||||
explainer_cfg = outline.explainer_cfg
|
explainer_cfg = outline.explainer_cfg
|
||||||
model_signature = outline.model_signature
|
model_signature = outline.model_signature
|
||||||
|
# RAJOUTER INDEXES
|
||||||
|
|
||||||
# Set seed
|
# Set seed
|
||||||
seed_everything(explaining_cfg.seed)
|
seed_everything(explaining_cfg.seed)
|
||||||
|
@ -77,6 +68,7 @@ if __name__ == "__main__":
|
||||||
makedirs(global_path)
|
makedirs(global_path)
|
||||||
write_yaml(cfg, os.path.join(global_path, "config.yaml"))
|
write_yaml(cfg, os.path.join(global_path, "config.yaml"))
|
||||||
write_json(model_info, os.path.join(global_path, "info.json"))
|
write_json(model_info, os.path.join(global_path, "info.json"))
|
||||||
|
# SET RUN DIR
|
||||||
|
|
||||||
global_path = os.path.join(
|
global_path = os.path.join(
|
||||||
global_path,
|
global_path,
|
||||||
|
@ -85,9 +77,11 @@ if __name__ == "__main__":
|
||||||
makedirs(global_path)
|
makedirs(global_path)
|
||||||
write_yaml(explaining_cfg, os.path.join(global_path, explaining_cfg.cfg_dest))
|
write_yaml(explaining_cfg, os.path.join(global_path, explaining_cfg.cfg_dest))
|
||||||
write_yaml(explainer_cfg, os.path.join(global_path, "explainer_cfg.yaml"))
|
write_yaml(explainer_cfg, os.path.join(global_path, "explainer_cfg.yaml"))
|
||||||
|
# SET EXPLAIN_DIR
|
||||||
|
|
||||||
global_path = os.path.join(global_path, obj_config_to_str(explaining_algorithm))
|
global_path = os.path.join(global_path, obj_config_to_str(explaining_algorithm))
|
||||||
makedirs(global_path)
|
makedirs(global_path)
|
||||||
|
# SET UP EXPLAINER
|
||||||
explainer = Explainer(
|
explainer = Explainer(
|
||||||
model=model,
|
model=model,
|
||||||
algorithm=explaining_algorithm,
|
algorithm=explaining_algorithm,
|
||||||
|
@ -102,6 +96,7 @@ if __name__ == "__main__":
|
||||||
return_type=explaining_cfg.model_config.return_type,
|
return_type=explaining_cfg.model_config.return_type,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
# CHERGER SUR LE GPU DIRECT
|
||||||
if not explaining_cfg.dataset.specific_items is None:
|
if not explaining_cfg.dataset.specific_items is None:
|
||||||
indexes = explaining_cfg.dataset.specific_items
|
indexes = explaining_cfg.dataset.specific_items
|
||||||
else:
|
else:
|
||||||
|
|
Loading…
Reference in New Issue