Merge branch 'master' of gitlab.xlim.fr:araison/explaining_framework
This commit is contained in:
		
						commit
						b73f087a6a
					
				
					 43 changed files with 428 additions and 0 deletions
				
			
		
							
								
								
									
										0
									
								
								explaining_framework/__init__.py
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										0
									
								
								explaining_framework/__init__.py
									
										
									
									
									
										Normal file
									
								
							
							
								
								
									
										0
									
								
								explaining_framework/explainers/__init__.py
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										0
									
								
								explaining_framework/explainers/__init__.py
									
										
									
									
									
										Normal file
									
								
							
										
											Binary file not shown.
										
									
								
							
							
								
								
									
										144
									
								
								explaining_framework/explainers/algorithms/deconvolution.py
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										144
									
								
								explaining_framework/explainers/algorithms/deconvolution.py
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,144 @@ | ||||||
|  | #!/usr/bin/env python | ||||||
|  | # -*- coding: utf-8 -*- | ||||||
|  | 
 | ||||||
|  | import logging | ||||||
|  | from typing import Optional, Tuple, Union | ||||||
|  | 
 | ||||||
|  | import torch | ||||||
|  | from captum.attr import (LRP, DeepLift, DeepLiftShap, FeatureAblation, | ||||||
|  |                          FeaturePermutation, GradientShap, GuidedBackprop, | ||||||
|  |                          GuidedGradCam, InputXGradient, IntegratedGradient, | ||||||
|  |                          Lime, Occlusion, Saliency) | ||||||
|  | from torch import Tensor | ||||||
|  | from torch_geometric.data import Batch, Data | ||||||
|  | from torch_geometric.explain import Explanation | ||||||
|  | from torch_geometric.explain.algorithm.base import ExplainerAlgorithm | ||||||
|  | from torch_geometric.explain.config import (ExplainerConfig, MaskType, | ||||||
|  |                                             ModelConfig, ModelMode, | ||||||
|  |                                             ModelTaskLevel) | ||||||
|  | from torch_geometric.nn import GCNConv, global_mean_pooling | ||||||
|  | from torch_geometric.nn.models.captum import (_raise_on_invalid_mask_type, | ||||||
|  |                                               to_captum_input, to_captum_model) | ||||||
|  | 
 | ||||||
|  | import torch.nn.functional as F | ||||||
|  | 
 | ||||||
|  | class FromCaptum(ExplainerAlgorithm): | ||||||
|  |     def __init__(captum_method, mask_type: str = "node", **kwargs): | ||||||
|  |         super().__init__() | ||||||
|  |         self.captum_model = captum_model | ||||||
|  |         self.mask_type = mask_type | ||||||
|  |         _raise_on_invalid_mask_type(mask_type) | ||||||
|  |         self.kwargs = kwargs | ||||||
|  | 
 | ||||||
|  |     def supports(self) -> bool: | ||||||
|  |         task_level = self.model_config.task_level | ||||||
|  |         if task_level not in [ModelTaskLevel.graph]: | ||||||
|  |             logging.error(f"Task level '{task_level.value}' not supported") | ||||||
|  |             return False | ||||||
|  | 
 | ||||||
|  |         edge_mask_type = self.explainer_config.edge_mask_type | ||||||
|  |         if edge_mask_type not in [MaskType.object, None]: | ||||||
|  |             logging.error(f"Edge mask type '{edge_mask_type.value}' not " f"supported") | ||||||
|  |             return False | ||||||
|  | 
 | ||||||
|  |         node_mask_type = self.explainer_config.node_mask_type | ||||||
|  |         if node_mask_type not in [MaskType.attributes]: | ||||||
|  |             logging.error(f"Node mask type '{node_mask_type.value}' not " f"supported.") | ||||||
|  |             return False | ||||||
|  | 
 | ||||||
|  |         return True | ||||||
|  | 
 | ||||||
|  |     def attr_to_tuple(self, attr, mask_type): | ||||||
|  |         _raise_on_invalid_mask_type(mask_type) | ||||||
|  |         if "node" == mask_type: | ||||||
|  |             node_mask = attr | ||||||
|  |             edge_mask = None | ||||||
|  | 
 | ||||||
|  |         if "edge" == mask_type: | ||||||
|  |             node_mask = None | ||||||
|  |             edge_mask = attr | ||||||
|  | 
 | ||||||
|  |         if "node_and_mask" == mask_type: | ||||||
|  |             node_mask = attr[0] | ||||||
|  |             edge_mask = attr[1] | ||||||
|  | 
 | ||||||
|  |     def forward( | ||||||
|  |         self, | ||||||
|  |         model: torch.nn.Module, | ||||||
|  |         x: Tensor, | ||||||
|  |         edge_index: Tensor, | ||||||
|  |         target: None, | ||||||
|  |     ) -> Explanation: | ||||||
|  | 
 | ||||||
|  |         converted_model = to_captum_model( | ||||||
|  |             model, | ||||||
|  |             mask_type=self.mask_type, | ||||||
|  |             output_idx=target, | ||||||
|  |         ) | ||||||
|  |         attrib = self.captum_model(converted_model) | ||||||
|  | 
 | ||||||
|  |         inputs, additional_forward_args = to_captum_input( | ||||||
|  |             x, edge_index, mask_type=self.mask_type | ||||||
|  |         ) | ||||||
|  |         attr = attrib.attribute( | ||||||
|  |             inputs, | ||||||
|  |             target=target, | ||||||
|  |             additional_forward_args=additional_forward_args, | ||||||
|  |             **self.kwargs, | ||||||
|  |         ) | ||||||
|  |         node_mask, edge_mask = self.attr_to_tuple(attr, self.mask_type) | ||||||
|  | 
 | ||||||
|  |         return Explanation( | ||||||
|  |             x=x, edge_index=edge_index, edge_mask=edge_mask, node_mask=node_mask | ||||||
|  |         ) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | if "__name__" == __main__: | ||||||
|  |     edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]], dtype=torch.long) | ||||||
|  |     x = torch.tensor([[-1], [0], [1]], dtype=torch.float) | ||||||
|  |     y = torch.tensor([1], dtype=torch.long) | ||||||
|  | 
 | ||||||
|  |     data = Data(x=x, edge_index=edge_index, y=y) | ||||||
|  |     batch = Batch().from_data_list([data]) | ||||||
|  | 
 | ||||||
|  |     class Model(torch.nn.Module): | ||||||
|  |         def __init__(self, dim_in, dim_out): | ||||||
|  |             super().__init__() | ||||||
|  |             self.dim_in = dim_in | ||||||
|  |             self.dim_out = dim_out | ||||||
|  |             self.conv = GCNConv(dim_in, dim_out) | ||||||
|  | 
 | ||||||
|  |         def forward(self, batch): | ||||||
|  |             x, edge_index, batch = batch.x, batch.edge_index, batch.batch | ||||||
|  |             x, edge_index = self.conv(x, edge_index) | ||||||
|  |             x = global_mean_pooling(x, edge_index, batch) | ||||||
|  |             return x | ||||||
|  | 
 | ||||||
|  |     model = Model(1,2) | ||||||
|  |     optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4) | ||||||
|  | 
 | ||||||
|  |     for epoch in range(1, 2): | ||||||
|  |         model.train() | ||||||
|  |         optimizer.zero_grad() | ||||||
|  |         log_logits = model(batch) | ||||||
|  |         loss = F.mse_loss(out,torch.ones(2)) | ||||||
|  |         loss.backward() | ||||||
|  |         optimizer.step() | ||||||
|  | 
 | ||||||
|  |     expla = [LRP, DeepLift, DeepLiftShap, FeatureAblation, | ||||||
|  |                          FeaturePermutation, GradientShap, GuidedBackprop, | ||||||
|  |                          GuidedGradCam, InputXGradient, IntegratedGradient, | ||||||
|  |                          Lime, Occlusion, Saliency] | ||||||
|  |     model = model.eval() | ||||||
|  |     for captum_exp in expla: | ||||||
|  |         try: | ||||||
|  |             exp = FromCaptum(captum_exp) | ||||||
|  |             attr = exp.forward(model,batch.x,batch.edge_index) | ||||||
|  |         except Exception as e: | ||||||
|  |             print(str(e)) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  |      | ||||||
|  | 
 | ||||||
							
								
								
									
										0
									
								
								explaining_framework/explainers/algorithms/deeplift.py
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										0
									
								
								explaining_framework/explainers/algorithms/deeplift.py
									
										
									
									
									
										Normal file
									
								
							
							
								
								
									
										0
									
								
								explaining_framework/explainers/algorithms/eixgnn.py
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										0
									
								
								explaining_framework/explainers/algorithms/eixgnn.py
									
										
									
									
									
										Normal file
									
								
							
							
								
								
									
										0
									
								
								explaining_framework/explainers/algorithms/graphsvx.py
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										0
									
								
								explaining_framework/explainers/algorithms/graphsvx.py
									
										
									
									
									
										Normal file
									
								
							
							
								
								
									
										0
									
								
								explaining_framework/explainers/algorithms/lime.py
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										0
									
								
								explaining_framework/explainers/algorithms/lime.py
									
										
									
									
									
										Normal file
									
								
							
							
								
								
									
										0
									
								
								explaining_framework/explainers/algorithms/lrp.py
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										0
									
								
								explaining_framework/explainers/algorithms/lrp.py
									
										
									
									
									
										Normal file
									
								
							
							
								
								
									
										0
									
								
								explaining_framework/explainers/algorithms/occlusion.py
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										0
									
								
								explaining_framework/explainers/algorithms/occlusion.py
									
										
									
									
									
										Normal file
									
								
							
							
								
								
									
										0
									
								
								explaining_framework/explainers/algorithms/random.py
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										0
									
								
								explaining_framework/explainers/algorithms/random.py
									
										
									
									
									
										Normal file
									
								
							
							
								
								
									
										0
									
								
								explaining_framework/explainers/algorithms/saliency.py
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										0
									
								
								explaining_framework/explainers/algorithms/saliency.py
									
										
									
									
									
										Normal file
									
								
							
							
								
								
									
										0
									
								
								explaining_framework/explainers/algorithms/subgraphx.py
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										0
									
								
								explaining_framework/explainers/algorithms/subgraphx.py
									
										
									
									
									
										Normal file
									
								
							
							
								
								
									
										0
									
								
								explaining_framework/explainers/wrappers/__init__.py
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										0
									
								
								explaining_framework/explainers/wrappers/__init__.py
									
										
									
									
									
										Normal file
									
								
							
							
								
								
									
										0
									
								
								explaining_framework/explainers/wrappers/captum.py
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										0
									
								
								explaining_framework/explainers/wrappers/captum.py
									
										
									
									
									
										Normal file
									
								
							
							
								
								
									
										0
									
								
								explaining_framework/explainers/wrappers/dig.py
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										0
									
								
								explaining_framework/explainers/wrappers/dig.py
									
										
									
									
									
										Normal file
									
								
							
							
								
								
									
										0
									
								
								explaining_framework/explainers/wrappers/graphframex.py
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										0
									
								
								explaining_framework/explainers/wrappers/graphframex.py
									
										
									
									
									
										Normal file
									
								
							
							
								
								
									
										0
									
								
								explaining_framework/explainers/wrappers/graphxai.py
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										0
									
								
								explaining_framework/explainers/wrappers/graphxai.py
									
										
									
									
									
										Normal file
									
								
							
							
								
								
									
										0
									
								
								explaining_framework/metric/__init__.py
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										0
									
								
								explaining_framework/metric/__init__.py
									
										
									
									
									
										Normal file
									
								
							
							
								
								
									
										24
									
								
								explaining_framework/metric/dataset.py
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										24
									
								
								explaining_framework/metric/dataset.py
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,24 @@ | ||||||
|  | #!/usr/bin/env python | ||||||
|  | # -*- coding: utf-8 -*- | ||||||
|  | 
 | ||||||
|  | from torch_geometric.data import Data, Batch, Dataset | ||||||
|  | from torch_geometric.utils import to_networkx | ||||||
|  | import networkx as nx | ||||||
|  | import numpy as np | ||||||
|  | import pandas as pd | ||||||
|  | 
 | ||||||
|  | class DatasetStatistic(object): | ||||||
|  |     def __init__(self, dataset): | ||||||
|  |         if not isinstance(dataset, Dataset): | ||||||
|  |             raise ValueError(f'{dataset} needs to be an PyG dataset object') | ||||||
|  |         self.dataset = dataset | ||||||
|  | 
 | ||||||
|  |         if len(dataset) == 0: | ||||||
|  |             self.task = 'node'  | ||||||
|  |         else: | ||||||
|  |             self.task = 'graph' | ||||||
|  |         self.study = None  | ||||||
|  | 
 | ||||||
|  |     def save_csv(self, path): | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
							
								
								
									
										28
									
								
								explaining_framework/metric/test.py
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										28
									
								
								explaining_framework/metric/test.py
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,28 @@ | ||||||
|  | #!/usr/bin/env python | ||||||
|  | # -*- coding: utf-8 -*- | ||||||
|  | 
 | ||||||
|  | from inspect import getmembers, isfunction, signature | ||||||
|  | import networkx as nx | ||||||
|  | from docstring_parser import parse | ||||||
|  | 
 | ||||||
|  | functions_list = getmembers(nx, isfunction) | ||||||
|  | rt_type =[] | ||||||
|  | for func in functions_list: | ||||||
|  |     name, f = func | ||||||
|  |     docstring = parse(f.__doc__) | ||||||
|  |     try: | ||||||
|  |         rt = docstring.returns.type_name | ||||||
|  |         rt_type.append(rt) | ||||||
|  |         if rt == 'int' or rt == 'float' or rt=='bool' or rt=='boolean' or rt=='dictionary': | ||||||
|  |             print(f'{name} : {rt}') | ||||||
|  |     except AttributeError: | ||||||
|  |         continue | ||||||
|  | 
 | ||||||
|  | print('int', rt_type.count('int')) | ||||||
|  | print('float', rt_type.count('float')) | ||||||
|  | print('bool', rt_type.count('bool')) | ||||||
|  | print('boolean', rt_type.count('boolean')) | ||||||
|  | print('dictionary', rt_type.count('dictionary')) | ||||||
|  | print('length func',len(functions_list)) | ||||||
|  | print('length rt',len(rt_type)) | ||||||
|  | 
 | ||||||
							
								
								
									
										0
									
								
								explaining_framework/stats/__init__.py
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										0
									
								
								explaining_framework/stats/__init__.py
									
										
									
									
									
										Normal file
									
								
							
							
								
								
									
										0
									
								
								explaining_framework/stats/dataset/__init__.py
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										0
									
								
								explaining_framework/stats/dataset/__init__.py
									
										
									
									
									
										Normal file
									
								
							
							
								
								
									
										0
									
								
								explaining_framework/stats/graph/__init__.py
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										0
									
								
								explaining_framework/stats/graph/__init__.py
									
										
									
									
									
										Normal file
									
								
							
							
								
								
									
										228
									
								
								explaining_framework/stats/graph/graph_stat.py
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										228
									
								
								explaining_framework/stats/graph/graph_stat.py
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,228 @@ | ||||||
|  | #!/usr/bin/env python | ||||||
|  | # -*- coding: utf-8 -*- | ||||||
|  | 
 | ||||||
|  | import types | ||||||
|  | from inspect import getmembers, isfunction, signature | ||||||
|  | 
 | ||||||
|  | import networkx as nx | ||||||
|  | import numpy as np | ||||||
|  | import torch | ||||||
|  | from torch_geometric.data import Data | ||||||
|  | from torch_geometric.utils import to_networkx | ||||||
|  | 
 | ||||||
|  | __maps__ = [ | ||||||
|  |     "adamic_adar_index", | ||||||
|  |     "approximate_current_flow_betweenness_centrality", | ||||||
|  |     "average_clustering", | ||||||
|  |     "average_degree_connectivity", | ||||||
|  |     "average_neighbor_degree", | ||||||
|  |     "average_node_connectivity", | ||||||
|  |     "average_shortest_path_length", | ||||||
|  |     "betweenness_centrality", | ||||||
|  |     "bridges", | ||||||
|  |     "closeness_centrality", | ||||||
|  |     "clustering", | ||||||
|  |     "cn_soundarajan_hopcroft", | ||||||
|  |     "common_neighbor_centrality", | ||||||
|  |     "communicability", | ||||||
|  |     "communicability_betweenness_centrality", | ||||||
|  |     "communicability_exp", | ||||||
|  |     "connected_components", | ||||||
|  |     "constraint", | ||||||
|  |     "core_number", | ||||||
|  |     "current_flow_betweenness_centrality", | ||||||
|  |     "current_flow_closeness_centrality", | ||||||
|  |     "cycle_basis", | ||||||
|  |     "degree_assortativity_coefficient", | ||||||
|  |     "degree_centrality", | ||||||
|  |     "degree_mixing_matrix", | ||||||
|  |     "degree_pearson_correlation_coefficient", | ||||||
|  |     "diameter", | ||||||
|  |     "dispersion", | ||||||
|  |     "dominating_set", | ||||||
|  |     "eccentricity", | ||||||
|  |     "effective_size", | ||||||
|  |     "eigenvector_centrality", | ||||||
|  |     "estrada_index", | ||||||
|  |     "generalized_degree", | ||||||
|  |     "global_efficiency", | ||||||
|  |     "global_reaching_centrality", | ||||||
|  |     "graph_clique_number", | ||||||
|  |     "harmonic_centrality", | ||||||
|  |     "has_bridges", | ||||||
|  |     "has_eulerian_path", | ||||||
|  |     "hits", | ||||||
|  |     "information_centrality", | ||||||
|  |     "is_at_free", | ||||||
|  |     "is_biconnected", | ||||||
|  |     "is_bipartite", | ||||||
|  |     "is_chordal", | ||||||
|  |     "is_connected", | ||||||
|  |     "is_directed_acyclic_graph", | ||||||
|  |     "is_distance_regular", | ||||||
|  |     "is_eulerian", | ||||||
|  |     "is_forest", | ||||||
|  |     "is_graphical", | ||||||
|  |     "is_multigraphical", | ||||||
|  |     "is_planar", | ||||||
|  |     "is_pseudographical", | ||||||
|  |     "is_regular", | ||||||
|  |     "is_semieulerian", | ||||||
|  |     "is_strongly_regular", | ||||||
|  |     "is_tree", | ||||||
|  |     "is_valid_degree_sequence_erdos_gallai", | ||||||
|  |     "is_valid_degree_sequence_havel_hakimi", | ||||||
|  |     "jaccard_coefficient", | ||||||
|  |     "k_components", | ||||||
|  |     "k_core", | ||||||
|  |     "katz_centrality", | ||||||
|  |     "load_centrality", | ||||||
|  |     "minimum_cycle_basis", | ||||||
|  |     "minimum_edge_cut", | ||||||
|  |     "minimum_node_cut", | ||||||
|  |     "node_clique_number", | ||||||
|  |     "node_connectivity", | ||||||
|  |     "node_degree_xy", | ||||||
|  |     "number_connected_components", | ||||||
|  |     "number_of_cliques", | ||||||
|  |     "number_of_isolates", | ||||||
|  |     "pagerank", | ||||||
|  |     "periphery", | ||||||
|  |     "preferential_attachment", | ||||||
|  |     "reciprocity", | ||||||
|  |     "resource_allocation_index", | ||||||
|  |     "rich_club_coefficient", | ||||||
|  |     "square_clustering", | ||||||
|  |     "stoer_wagner", | ||||||
|  |     "subgraph_centrality", | ||||||
|  |     "subgraph_centrality_exp", | ||||||
|  |     "transitivity", | ||||||
|  |     "triangles", | ||||||
|  |     "voterank", | ||||||
|  |     "wiener_index", | ||||||
|  |     "algebraic_connectivity", | ||||||
|  |     "degree", | ||||||
|  |     "density", | ||||||
|  |     "normalized_laplacian_spectrum", | ||||||
|  |     "number_of_edges", | ||||||
|  |     "number_of_nodes", | ||||||
|  |     "number_of_selfloops", | ||||||
|  | ] | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class GraphStat(object): | ||||||
|  |     def __init__(self): | ||||||
|  |         self.maps = { | ||||||
|  |             "networkx": self.available_map_networkx(), | ||||||
|  |             "torch_geometric": self.available_map_torch_geometric(), | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |     def available_map_networkx(self): | ||||||
|  |         functions_list = getmembers(nx, isfunction) | ||||||
|  | 
 | ||||||
|  |         maps = {} | ||||||
|  |         for func in functions_list: | ||||||
|  |             name, f = func | ||||||
|  |             if name in __maps__: | ||||||
|  |                 maps[name] = f | ||||||
|  |         return maps | ||||||
|  | 
 | ||||||
|  |     def available_map_torch_geometric(self): | ||||||
|  |         names = [ | ||||||
|  |             "num_nodes", | ||||||
|  |             "num_edges", | ||||||
|  |             "has_self_loops", | ||||||
|  |             "has_isolated_nodes", | ||||||
|  |             "num_nodes_features", | ||||||
|  |             "y", | ||||||
|  |         ] | ||||||
|  |         maps = { | ||||||
|  |             name: lambda x, name=name: x.__getattr__(name) if hasattr(x, name) else None | ||||||
|  |             for name in names | ||||||
|  |         } | ||||||
|  |         return maps | ||||||
|  | 
 | ||||||
|  |     def __call__(self, data): | ||||||
|  |         data_ = data.__copy__() | ||||||
|  |         datahash = hash(data.__repr__) | ||||||
|  |         stats = {} | ||||||
|  |         for k, v in self.maps.items(): | ||||||
|  |             if k == "networkx": | ||||||
|  |                 _data_ = to_networkx(data) | ||||||
|  |                 _data_ = _data_.to_undirected() | ||||||
|  |             elif k == "torch_geometric": | ||||||
|  |                 _data_ = data.__copy__() | ||||||
|  |             for name, func in v.items(): | ||||||
|  |                 try: | ||||||
|  |                     val = func(_data_) | ||||||
|  |                 except: | ||||||
|  |                     val = None | ||||||
|  |                 if callable(val): | ||||||
|  |                     val = val() | ||||||
|  |                 if isinstance(val, types.GeneratorType): | ||||||
|  |                     try: | ||||||
|  |                         val = dict(val) | ||||||
|  |                     except: | ||||||
|  |                         val = None | ||||||
|  |                 if name == "hits": | ||||||
|  |                     val = val[0] | ||||||
|  |                 if name == "k_components": | ||||||
|  |                     for key, value in val.items(): | ||||||
|  |                         val[key] = list(val[key][0]) | ||||||
|  | 
 | ||||||
|  |                 stats[name] = self.convert(val) | ||||||
|  |         stats["hash"] = datahash | ||||||
|  |         return stats | ||||||
|  | 
 | ||||||
|  |     def convert(self, val, K=4): | ||||||
|  | 
 | ||||||
|  |         if type(val) == set: | ||||||
|  |             val = list(val) | ||||||
|  |             return val | ||||||
|  | 
 | ||||||
|  |         elif isinstance(val, torch.Tensor): | ||||||
|  |             val = val.cpu().numpy().tolist() | ||||||
|  |             return val | ||||||
|  | 
 | ||||||
|  |         elif isinstance(val, list): | ||||||
|  |             for ind in range(len(val)): | ||||||
|  |                 item = val[ind] | ||||||
|  |                 if isinstance(item, list): | ||||||
|  |                     for subind in range(len(item)): | ||||||
|  |                         subitem = item[subind] | ||||||
|  |                         if type(subitem) == float: | ||||||
|  |                             item[subind] = round(subitem, K) | ||||||
|  | 
 | ||||||
|  |                 else: | ||||||
|  |                     if type(item) == float: | ||||||
|  |                         val[ind] = round(item, K) | ||||||
|  |             return val | ||||||
|  | 
 | ||||||
|  |         elif isinstance(val, dict): | ||||||
|  |             for k, v in val.items(): | ||||||
|  |                 if isinstance(v, dict): | ||||||
|  |                     for k1, v1 in v.items(): | ||||||
|  |                         if type(v1) == float: | ||||||
|  |                             v[k1] = round(v1, K) | ||||||
|  |                         if type(v1) == set: | ||||||
|  |                             v[k1] = list(v1) | ||||||
|  | 
 | ||||||
|  |                 else: | ||||||
|  |                     if type(v) == float: | ||||||
|  |                         val[k] = round(v, K) | ||||||
|  |             return val | ||||||
|  | 
 | ||||||
|  |         elif isinstance(val, nx.classes.reportviews.DegreeView): | ||||||
|  |             val = np.array(val).tolist() | ||||||
|  |             return val | ||||||
|  | 
 | ||||||
|  |         elif isinstance(val, nx.Graph): | ||||||
|  |             val = nx.node_link_data(val) | ||||||
|  |             return val | ||||||
|  | 
 | ||||||
|  |         elif isinstance(val, np.ndarray): | ||||||
|  |             val = np.around(val, decimals=K) | ||||||
|  |             return val.tolist() | ||||||
|  | 
 | ||||||
|  |         else: | ||||||
|  |             return val | ||||||
							
								
								
									
										0
									
								
								explaining_framework/utils/__init__.py
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										0
									
								
								explaining_framework/utils/__init__.py
									
										
									
									
									
										Normal file
									
								
							
							
								
								
									
										0
									
								
								explaining_framework/utils/explainer/__init__.py
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										0
									
								
								explaining_framework/utils/explainer/__init__.py
									
										
									
									
									
										Normal file
									
								
							
							
								
								
									
										3
									
								
								explaining_framework/utils/explainer/base.py
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										3
									
								
								explaining_framework/utils/explainer/base.py
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,3 @@ | ||||||
|  | class BaseExplaining(object): | ||||||
|  |     def __init__(self,model,explainer_name:wq | ||||||
|  | 
 | ||||||
							
								
								
									
										1
									
								
								explaining_framework/utils/explainer/from_captum.py
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										1
									
								
								explaining_framework/utils/explainer/from_captum.py
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1 @@ | ||||||
|  | from torch_geometric.nn.models.captum import CaptumModel | ||||||
							
								
								
									
										0
									
								
								explaining_framework/utils/graphgym/__init__.py
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										0
									
								
								explaining_framework/utils/graphgym/__init__.py
									
										
									
									
									
										Normal file
									
								
							
							
								
								
									
										0
									
								
								explaining_framework/utils/visualizer/__init__.py
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										0
									
								
								explaining_framework/utils/visualizer/__init__.py
									
										
									
									
									
										Normal file
									
								
							
		Loading…
	
	Add table
		
		Reference in a new issue
	
	 araison
						araison