This commit is contained in:
araison 2023-01-04 09:25:41 +01:00
parent fb012ad723
commit d9628ff947
3 changed files with 163 additions and 187 deletions

View File

@ -3,6 +3,17 @@ import itertools
from typing import Any from typing import Any
from eixgnn.eixgnn import EiXGNN from eixgnn.eixgnn import EiXGNN
from scgnn.scgnn import SCGNN
from torch_geometric import seed_everything
from torch_geometric.data import Batch, Data
from torch_geometric.explain import Explainer
from torch_geometric.explain.config import ThresholdConfig
from torch_geometric.graphgym.config import cfg
from torch_geometric.graphgym.loader import create_dataset
from torch_geometric.graphgym.model_builder import cfg, create_model
from torch_geometric.graphgym.utils.device import auto_select_device
from torch_geometric.loader.dataloader import DataLoader
from explaining_framework.config.explainer_config.eixgnn_config import \ from explaining_framework.config.explainer_config.eixgnn_config import \
eixgnn_cfg eixgnn_cfg
from explaining_framework.config.explainer_config.scgnn_config import scgnn_cfg from explaining_framework.config.explainer_config.scgnn_config import scgnn_cfg
@ -18,15 +29,11 @@ from explaining_framework.stats.graph.graph_stat import GraphStat
from explaining_framework.utils.explaining.load_ckpt import (LoadModelInfo, from explaining_framework.utils.explaining.load_ckpt import (LoadModelInfo,
_load_ckpt) _load_ckpt)
from explaining_framework.utils.explanation.adjust import Adjust from explaining_framework.utils.explanation.adjust import Adjust
from scgnn.scgnn import SCGNN from explaining_framework.utils.explanation.io import (
from torch_geometric.data import Batch, Data _get_explanation, _load_explanation, _save_explanation,
from torch_geometric.explain import Explainer explanation_verification, get_pred)
from torch_geometric.explain.config import ThresholdConfig from explaining_framework.utils.io import (is_exists, obj_config_to_str,
from torch_geometric.graphgym.config import cfg read_json, write_json, write_yaml)
from torch_geometric.graphgym.loader import create_dataset
from torch_geometric.graphgym.model_builder import cfg, create_model
from torch_geometric.graphgym.utils.device import auto_select_device
from torch_geometric.loader.dataloader import DataLoader
all__captum = [ all__captum = [
"LRP", "LRP",
@ -88,10 +95,15 @@ all_robust = [
] ]
all_sparsity = ["l0"] all_sparsity = ["l0"]
adjust_pattern = 'ranp' adjust_pattern = "ranp"
all_adjusts_filters = [''.join(filters) for i in range(len(adjust_pattern)+1)for filters in itertools.permutations(adjust_pattern,i)] all_adjusts_filters = [
"".join(filters)
for i in range(len(adjust_pattern) + 1)
for filters in itertools.permutations(adjust_pattern, i)
]
all_threshold_type = ["topk_hard", "hard", "topk"]
all_threshold_type = ['topk_hard','hard','topk']
class ExplainingOutline(object): class ExplainingOutline(object):
def __init__(self, explaining_cfg_path: str): def __init__(self, explaining_cfg_path: str):
@ -131,6 +143,8 @@ class ExplainingOutline(object):
self.load_threshold() self.load_threshold()
self.load_graphstat() self.load_graphstat()
seed_everything(self.explaining_cfg.seed)
def load_model_to_hardware(self): def load_model_to_hardware(self):
auto_select_device() auto_select_device()
device = self.cfg.accelerator device = self.cfg.accelerator
@ -300,17 +314,23 @@ class ExplainingOutline(object):
if self.explaining_cfg is None: if self.explaining_cfg is None:
self.load_explaining_cfg() self.load_explaining_cfg()
name = self.explaining_cfg.metrics.fidelity.name name = self.explaining_cfg.metrics.fidelity.name
if name == 'all': if name == "all":
all_metrics = [ all_metrics = [
Fidelity(name=name, model=self.model) for name in all_fidelity Fidelity(name=name, model=self.model) for name in all_fidelity
] ]
elif isinstance(name,str): elif isinstance(name, str):
if name in all_fidelity: if name in all_fidelity:
all_metrics = [Fidelity(name=name, model=self.model)] all_metrics = [Fidelity(name=name, model=self.model)]
else: else:
raise ValueError(f'This fidelity metric {name} is nor supported yet. Supported are {all_fidelity}') raise ValueError(
elif isinstance(name,list): f"This fidelity metric {name} is nor supported yet. Supported are {all_fidelity}"
all_metrics = [Fidelity(name=name, model=self.model) for name_ in name if name_ in all_fidelity] )
elif isinstance(name, list):
all_metrics = [
Fidelity(name=name, model=self.model)
for name_ in name
if name_ in all_fidelity
]
elif name is None: elif name is None:
all_metrics = [] all_metrics = []
self.fidelities = all_metrics self.fidelities = all_metrics
@ -321,23 +341,23 @@ class ExplainingOutline(object):
if self.explaining_cfg is None: if self.explaining_cfg is None:
self.load_explaining_cfg() self.load_explaining_cfg()
name = self.explaining_cfg.metrics.sparsity.name name = self.explaining_cfg.metrics.sparsity.name
if name == 'all': if name == "all":
all_metrics = [ all_metrics = [Sparsity(name=name) for name in all_sparsity]
Sparsity(name=name) for name in all_sparsity elif isinstance(name, str):
]
elif isinstance(name,str):
if name in all_sparsity: if name in all_sparsity:
all_metrics = [Sparsity(name=name)] all_metrics = [Sparsity(name=name)]
else: else:
raise ValueError(f'This sparsity metric {name} is nor supported yet. Supported are {all_sparsity}') raise ValueError(
elif isinstance(name,list): f"This sparsity metric {name} is nor supported yet. Supported are {all_sparsity}"
all_metrics = [Sparsity(name=name) for name_ in name if name_ in all_sparsity] )
elif isinstance(name, list):
all_metrics = [
Sparsity(name=name) for name_ in name if name_ in all_sparsity
]
elif name is None: elif name is None:
all_metrics = [] all_metrics = []
self.sparsities = all_metrics self.sparsities = all_metrics
def load_accuracy(self): def load_accuracy(self):
if self.cfg is None: if self.cfg is None:
self.load_cfg() self.load_cfg()
@ -346,24 +366,26 @@ class ExplainingOutline(object):
if self.explaining_cfg.dataset.name == "BASHAPES": if self.explaining_cfg.dataset.name == "BASHAPES":
name = self.explaining_cfg.metrics.accuracy.name name = self.explaining_cfg.metrics.accuracy.name
if name == 'all': if name == "all":
all_metrics = [ all_metrics = [Accuracy(name=name) for name in all_accuracy]
Accuracy(name=name) for name in all_accuracy elif isinstance(name, str):
]
elif isinstance(name,str):
if name in all_accuracy: if name in all_accuracy:
all_metrics = [Accuracy(name=name)] all_metrics = [Accuracy(name=name)]
else: else:
raise ValueError(f'This accuracy metric {name} is nor supported yet. Supported are {all_accuracy}') raise ValueError(
elif isinstance(name,list): f"This accuracy metric {name} is nor supported yet. Supported are {all_accuracy}"
all_metrics = [Accuracy(name=name) for name_ in name if name_ in all_accuracy] )
elif isinstance(name, list):
all_metrics = [
Accuracy(name=name) for name_ in name if name_ in all_accuracy
]
elif name is None: elif name is None:
all_metrics = [] all_metrics = []
self.accuraties = all_metrics self.accuraties = all_metrics
else: else:
raise ValueError(f'Provided dataset needs explanation groundtruths for using Accuracies metric, e.g BASHAPES dataset') raise ValueError(
f"Provided dataset needs explanation groundtruths for using Accuracies metric, e.g BASHAPES dataset"
)
def load_metric(self): def load_metric(self):
if self.cfg is None: if self.cfg is None:
@ -377,8 +399,7 @@ class ExplainingOutline(object):
if self.fidelities is None: if self.fidelities is None:
self.load_fidelity() self.load_fidelity()
self.metrics = self.fidelities+self.accuraties+self.sparsities self.metrics = self.fidelities + self.accuraties + self.sparsities
def load_attack(self): def load_attack(self):
if self.cfg is None: if self.cfg is None:
@ -386,17 +407,21 @@ class ExplainingOutline(object):
if self.explaining_cfg is None: if self.explaining_cfg is None:
self.load_explaining_cfg() self.load_explaining_cfg()
name = self.explaining_cfg.attack.name name = self.explaining_cfg.attack.name
if name == 'all': if name == "all":
all_metrics = [ all_metrics = [Attack(name=name, model=self.model) for name in all_robust]
Attack(name=name,model=self.model) for name in all_robust elif isinstance(name, str):
]
elif isinstance(name,str):
if name in all_robust: if name in all_robust:
all_metrics = [Attack(name=name,model=self.model)] all_metrics = [Attack(name=name, model=self.model)]
else: else:
raise ValueError(f'This Attack metric {name} is not supported yet. Supported are {all_robust}') raise ValueError(
elif isinstance(name,list): f"This Attack metric {name} is not supported yet. Supported are {all_robust}"
all_metrics = [Attack(name=name,model=self.model) for name_ in name if name_ in all_robust] )
elif isinstance(name, list):
all_metrics = [
Attack(name=name, model=self.model)
for name_ in name
if name_ in all_robust
]
elif name is None: elif name is None:
all_metrics = [] all_metrics = []
self.attacks = all_metrics self.attacks = all_metrics
@ -407,13 +432,17 @@ class ExplainingOutline(object):
strategy = self.explaining_cfg.adjust.strategy strategy = self.explaining_cfg.adjust.strategy
if strategy == "all": if strategy == "all":
self.adjusts = [Adjust(strategy=strat) for strat in all_adjusts_filters] self.adjusts = [Adjust(strategy=strat) for strat in all_adjusts_filters]
elif isinstance(name,str): elif isinstance(name, str):
if name in all_adjusts_filters: if name in all_adjusts_filters:
all_metrics = [Adjust(strategy=name)] all_metrics = [Adjust(strategy=name)]
else: else:
raise ValueError(f'This Adjust metric {name} is not supported yet. Supported are {all_adjusts_filters}') raise ValueError(
elif isinstance(name,list): f"This Adjust metric {name} is not supported yet. Supported are {all_adjusts_filters}"
all_metrics = [Adjust(strategy=name_) for name_ in name if name_ in all_robust] )
elif isinstance(name, list):
all_metrics = [
Adjust(strategy=name_) for name_ in name if name_ in all_robust
]
elif name is None: elif name is None:
all_metrics = [] all_metrics = []
self.adjusts = all_metrics self.adjusts = all_metrics
@ -421,70 +450,90 @@ class ExplainingOutline(object):
def load_threshold(self): def load_threshold(self):
if self.explaining_cfg is None: if self.explaining_cfg is None:
self.load_explaining_cfg() self.load_explaining_cfg()
threshold_type =self.explaining_cfg.threshold_config.type threshold_type = self.explaining_cfg.threshold_config.type
if threshold_type == 'all': if threshold_type == "all":
th_hard = [{"threshold_type": 'hard',"value": th_value} for th_value in self.explaining_cfg.threshold.value.hard] th_hard = [
th_topk = [{"threshold_type": th_type,"value": th_value} for th_value in self.explaining_cfg.threshold.value.topk f or th_type in all_threshold_type if 'topk' in th_type] {"threshold_type": "hard", "value": th_value}
for th_value in self.explaining_cfg.threshold.value.hard
]
th_topk = [
{"threshold_type": th_type, "value": th_value}
for th_value in self.explaining_cfg.threshold.value.topk
for th_type in all_threshold_type
if "topk" in th_type
]
all_threshold = th_hard + th_topk all_threshold = th_hard + th_topk
elif isinstance(threshold_type,str): elif isinstance(threshold_type, str):
if threshold_type in all_threshold_type: if threshold_type in all_threshold_type:
if 'topk' in threshold_type: if "topk" in threshold_type:
all_threshold = [{ all_threshold = [
{
"threshold_type": threshold_type, "threshold_type": threshold_type,
"value": threshold_value, "value": threshold_value,
} for threshold_value in self.explaining_cfg.threshold.value.topk] }
elif threshold_type == 'hard': for threshold_value in self.explaining_cfg.threshold.value.topk
all_threshold = [{ ]
elif threshold_type == "hard":
all_threshold = [
{
"threshold_type": threshold_type, "threshold_type": threshold_type,
"value": threshold_value, "value": threshold_value,
} for threshold_value in self.explaining_cfg.threshold.value.hard] }
elif isinstance(threshold_type,list): for threshold_value in self.explaining_cfg.threshold.value.hard
]
elif isinstance(threshold_type, list):
all_threshold = [] all_threshold = []
for tf_type in threshold_type: for tf_type in threshold_type:
if 'topk' in th_type: if "topk" in th_type:
all_threshold.expend([{ all_threshold.expend(
[
{
"threshold_type": threshold_type, "threshold_type": threshold_type,
"value": threshold_value, "value": threshold_value,
} for threshold_value in self.explaining_cfg.threshold.value.topk]) }
elif th_type == 'hard': for threshold_value in self.explaining_cfg.threshold.value.topk
all_threshold.expend([{ ]
)
elif th_type == "hard":
all_threshold.expend(
[
{
"threshold_type": threshold_type, "threshold_type": threshold_type,
"value": threshold_value, "value": threshold_value,
} for threshold_value in self.explaining_cfg.threshold.value.hard]) }
for threshold_value in self.explaining_cfg.threshold.value.hard
]
)
elif threshold_type is None: elif threshold_type is None:
all_threshold = [] all_threshold = []
self.thresholds_configs = all_threshold self.thresholds_configs = all_threshold
def set_explainer_threshold_config(self,threshold_config): def set_explainer_threshold_config(self, threshold_config):
self.explainer.threshold_config = ThresholdConfig.cast(threshold_config) self.explainer.threshold_config = ThresholdConfig.cast(threshold_config)
def load_graphstat(self): def load_graphstat(self):
self.graphstat = GraphStat() self.graphstat = GraphStat()
def get_explanation_(self,item:Data,path:str): def get_explanation(self, item: Data, path: str):
if is_exists(path): if is_exists(path):
if self.explaining_cfg.explainer.force: if self.explaining_cfg.explainer.force:
explanation = get_explanation(self.explainer, item) explanation = _get_explanation(self.explainer, item)
else: else:
explanation = load_explanation(path) explanation = _load_explanation(path)
else: else:
explanation = get_explanation(explainer, item) explanation = _get_explanation(self.explainer, item)
save_explanation(explanation,path) _save_explanation(explanation, path)
explanation = explanation.to(cfg.accelerator)
return explanation return explanation
def get_stat(self, item: Data, path: str):
class Explaining(object): if self.graphstat is None:
def __init__(self,outline:ExplainingOutline): self.load_graphstat()
self.outline = outline if is_exists(path):
def run(self):
pass pass
else:
def explain(self): if item.num_nodes <= 500:
item, index = self.get_item() stat = self.graphstat(item)
not_none = item is None or index is None write_json(stat, path)
whœ
while

View File

@ -7,15 +7,17 @@ from torch_geometric.data import Data
from torch_geometric.explain.explanation import Explanation from torch_geometric.explain.explanation import Explanation
def get_explanation(explainer, item): def _get_explanation(explainer, item):
explanation = explainer( explanation = explainer(
x=item.x, x=item.x,
edge_index=item.edge_index, edge_index=item.edge_index,
index=int(item.y), index=int(item.y),
target=item.y, target=item.y,
) )
# TODO return None if pas bien plutot if not explanation_verification(explanation):
assert explanation_verification(explanation) # WARNING + LOG
return None
else:
return explanation return explanation
@ -55,7 +57,7 @@ def explanation_verification(exp: Explanation) -> bool:
return is_good return is_good
def save_explanation(exp: Explanation, path: str) -> None: def _save_explanation(exp: Explanation, path: str) -> None:
data = copy.copy(exp).to_dict() data = copy.copy(exp).to_dict()
for k, v in data.items(): for k, v in data.items():
if isinstance(v, torch.Tensor): if isinstance(v, torch.Tensor):
@ -65,7 +67,7 @@ def save_explanation(exp: Explanation, path: str) -> None:
json.dump(data, f) json.dump(data, f)
def load_explanation(path: str) -> Explanation: def _load_explanation(path: str) -> Explanation:
with open(path, "r") as f: with open(path, "r") as f:
data = json.load(f) data = json.load(f)
for k, v in data.items(): for k, v in data.items():
@ -77,12 +79,3 @@ def load_explanation(path: str) -> Explanation:
return Explanation.from_dict(data) return Explanation.from_dict(data)
def normalize_explanation_masks(exp: Explanation, p: str = "inf") -> Explanation:
exp = copy.copy(exp)
data = exp.to_dict()
for k, v in data.items():
if "_mask" in k and isinstance(v, torch.FloatTensor):
norm = torch.norm(input=data[k], p=p, dim=None).item()
if norm.item() > 0:
data[k] = data[k] / norm
return exp

82
main.py
View File

@ -27,99 +27,33 @@ from explaining_framework.utils.io import (is_exists, obj_config_to_str,
# inference, time, force, # inference, time, force,
def get_pred(explainer, explanation):
pred = explainer.get_prediction(x=explanation.x, edge_index=explanation.edge_index)[
0
]
setattr(explanation, "pred", pred)
data = explanation.to_dict()
if not data.get("node_mask") is None or not data.get("edge_mask") is None:
pred_masked = explainer.get_masked_prediction(
x=explanation.x,
edge_index=explanation.edge_index,
node_mask=data.get("node_mask"),
edge_mask=data.get("edge_mask"),
)[0]
setattr(explanation, "pred_exp", pred_masked)
if __name__ == "__main__": if __name__ == "__main__":
args = parse_args() args = parse_args()
outline = ExplainingOutline(args.explaining_cfg_file) outline = ExplainingOutline(args.explaining_cfg_file)
auto_select_device()
# Load components # Load components
dataset = outline.dataset
model = outline.model.to(cfg.accelerator)
model = model.eval()
model_info = outline.model_info
metrics = outline.metrics
explaining_algorithm = outline.explaining_algorithm
attacks = outline.attacks
explainer_cfg = outline.explainer_cfg
model_signature = outline.model_signature
# RAJOUTER INDEXES # RAJOUTER INDEXES
# Set seed
seed_everything(explaining_cfg.seed)
# Global path # Global path
global_path = os.path.join(explaining_cfg.out_dir, model_signature) global_path = os.path.join(outline.explaining_cfg.out_dir, outline.model_signature, outline.explaining_cfg.explainer.name + "_" + obj_config_to_str(outline.explaining_algorithm))
makedirs(global_path) makedirs(global_path)
write_yaml(cfg, os.path.join(global_path, "config.yaml")) write_yaml(cfg, os.path.join(global_path, "config.yaml"))
write_json(model_info, os.path.join(global_path, "info.json")) write_json(model_info, os.path.join(global_path, "info.json"))
# SET RUN DIR
global_path = os.path.join(
global_path,
explaining_cfg.explainer.name + "_" + obj_config_to_str(explaining_algorithm),
)
makedirs(global_path) makedirs(global_path)
write_yaml(explaining_cfg, os.path.join(global_path, explaining_cfg.cfg_dest)) write_yaml(outline.explaining_cfg, os.path.join(global_path, explaining_cfg.cfg_dest))
write_yaml(explainer_cfg, os.path.join(global_path, "explainer_cfg.yaml")) write_yaml(outline.explainer_cfg, os.path.join(global_path, "explainer_cfg.yaml"))
# SET EXPLAIN_DIR
global_path = os.path.join(global_path, obj_config_to_str(explaining_algorithm)) global_path = os.path.join(global_path, obj_config_to_str(outline.explaining_algorithm))
makedirs(global_path) makedirs(global_path)
# SET UP EXPLAINER # SET UP EXPLAINER
explainer = Explainer(
model=model,
algorithm=explaining_algorithm,
explainer_config=dict(
explanation_type=explaining_cfg.explanation_type,
node_mask_type="object",
edge_mask_type="object",
),
model_config=dict(
mode="regression",
task_level=cfg.dataset.task,
return_type=explaining_cfg.model_config.return_type,
),
)
# CHERGER SUR LE GPU DIRECT
if not explaining_cfg.dataset.specific_items is None:
indexes = explaining_cfg.dataset.specific_items
else:
indexes = range(len(dataset))
# Save explaining configuration # Save explaining configuration
for index, item in zip(indexes, dataset): item,index = outline.get_item()
item = item.to(cfg.accelerator) while not(item is None or index is None):
save_raw_path = os.path.join(global_path, "raw") raw_path = os.path.join(global_path, "raw")
makedirs(save_raw_path) makedirs(raw_path)
explanation_path = os.path.join(save_raw_path, f"{index}.json") explanation_path = os.path.join(save_raw_path, f"{index}.json")
if is_exists(explanation_path):
if explaining_cfg.explainer.force:
explanation = get_explanation(explainer, item)
else:
explanation = load_explanation(explanation_path)
else:
explanation = get_explanation(explainer, item)
explanation = explanation.to(cfg.accelerator)
get_pred(explainer=explainer, explanation=explanation)
save_explanation(explanation, explanation_path)
for apply_relu in [True, False]: for apply_relu in [True, False]:
for apply_absolute in [True, False]: for apply_absolute in [True, False]:
adjust = Adjust(apply_relu=apply_relu, apply_absolute=apply_absolute) adjust = Adjust(apply_relu=apply_relu, apply_absolute=apply_absolute)