New fixes + reformat
This commit is contained in:
		
							parent
							
								
									09a20b891f
								
							
						
					
					
						commit
						5d2cacfa05
					
				
					 11 changed files with 232 additions and 125 deletions
				
			
		
							
								
								
									
										0
									
								
								stat/graph/__init__.py
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										0
									
								
								stat/graph/__init__.py
									
										
									
									
									
										Normal file
									
								
							
							
								
								
									
										228
									
								
								stat/graph/graph_stat.py
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										228
									
								
								stat/graph/graph_stat.py
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,228 @@ | ||||||
|  | #!/usr/bin/env python | ||||||
|  | # -*- coding: utf-8 -*- | ||||||
|  | 
 | ||||||
|  | import types | ||||||
|  | from inspect import getmembers, isfunction, signature | ||||||
|  | 
 | ||||||
|  | import networkx as nx | ||||||
|  | import numpy as np | ||||||
|  | import torch | ||||||
|  | from torch_geometric.data import Data | ||||||
|  | from torch_geometric.utils import to_networkx | ||||||
|  | 
 | ||||||
|  | __maps__ = [ | ||||||
|  |     "adamic_adar_index", | ||||||
|  |     "approximate_current_flow_betweenness_centrality", | ||||||
|  |     "average_clustering", | ||||||
|  |     "average_degree_connectivity", | ||||||
|  |     "average_neighbor_degree", | ||||||
|  |     "average_node_connectivity", | ||||||
|  |     "average_shortest_path_length", | ||||||
|  |     "betweenness_centrality", | ||||||
|  |     "bridges", | ||||||
|  |     "closeness_centrality", | ||||||
|  |     "clustering", | ||||||
|  |     "cn_soundarajan_hopcroft", | ||||||
|  |     "common_neighbor_centrality", | ||||||
|  |     "communicability", | ||||||
|  |     "communicability_betweenness_centrality", | ||||||
|  |     "communicability_exp", | ||||||
|  |     "connected_components", | ||||||
|  |     "constraint", | ||||||
|  |     "core_number", | ||||||
|  |     "current_flow_betweenness_centrality", | ||||||
|  |     "current_flow_closeness_centrality", | ||||||
|  |     "cycle_basis", | ||||||
|  |     "degree_assortativity_coefficient", | ||||||
|  |     "degree_centrality", | ||||||
|  |     "degree_mixing_matrix", | ||||||
|  |     "degree_pearson_correlation_coefficient", | ||||||
|  |     "diameter", | ||||||
|  |     "dispersion", | ||||||
|  |     "dominating_set", | ||||||
|  |     "eccentricity", | ||||||
|  |     "effective_size", | ||||||
|  |     "eigenvector_centrality", | ||||||
|  |     "estrada_index", | ||||||
|  |     "generalized_degree", | ||||||
|  |     "global_efficiency", | ||||||
|  |     "global_reaching_centrality", | ||||||
|  |     "graph_clique_number", | ||||||
|  |     "harmonic_centrality", | ||||||
|  |     "has_bridges", | ||||||
|  |     "has_eulerian_path", | ||||||
|  |     "hits", | ||||||
|  |     "information_centrality", | ||||||
|  |     "is_at_free", | ||||||
|  |     "is_biconnected", | ||||||
|  |     "is_bipartite", | ||||||
|  |     "is_chordal", | ||||||
|  |     "is_connected", | ||||||
|  |     "is_directed_acyclic_graph", | ||||||
|  |     "is_distance_regular", | ||||||
|  |     "is_eulerian", | ||||||
|  |     "is_forest", | ||||||
|  |     "is_graphical", | ||||||
|  |     "is_multigraphical", | ||||||
|  |     "is_planar", | ||||||
|  |     "is_pseudographical", | ||||||
|  |     "is_regular", | ||||||
|  |     "is_semieulerian", | ||||||
|  |     "is_strongly_regular", | ||||||
|  |     "is_tree", | ||||||
|  |     "is_valid_degree_sequence_erdos_gallai", | ||||||
|  |     "is_valid_degree_sequence_havel_hakimi", | ||||||
|  |     "jaccard_coefficient", | ||||||
|  |     "k_components", | ||||||
|  |     "k_core", | ||||||
|  |     "katz_centrality", | ||||||
|  |     "load_centrality", | ||||||
|  |     "minimum_cycle_basis", | ||||||
|  |     "minimum_edge_cut", | ||||||
|  |     "minimum_node_cut", | ||||||
|  |     "node_clique_number", | ||||||
|  |     "node_connectivity", | ||||||
|  |     "node_degree_xy", | ||||||
|  |     "number_connected_components", | ||||||
|  |     "number_of_cliques", | ||||||
|  |     "number_of_isolates", | ||||||
|  |     "pagerank", | ||||||
|  |     "periphery", | ||||||
|  |     "preferential_attachment", | ||||||
|  |     "reciprocity", | ||||||
|  |     "resource_allocation_index", | ||||||
|  |     "rich_club_coefficient", | ||||||
|  |     "square_clustering", | ||||||
|  |     "stoer_wagner", | ||||||
|  |     "subgraph_centrality", | ||||||
|  |     "subgraph_centrality_exp", | ||||||
|  |     "transitivity", | ||||||
|  |     "triangles", | ||||||
|  |     "voterank", | ||||||
|  |     "wiener_index", | ||||||
|  |     "algebraic_connectivity", | ||||||
|  |     "degree", | ||||||
|  |     "density", | ||||||
|  |     "normalized_laplacian_spectrum", | ||||||
|  |     "number_of_edges", | ||||||
|  |     "number_of_nodes", | ||||||
|  |     "number_of_selfloops", | ||||||
|  | ] | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class GraphStat(object): | ||||||
|  |     def __init__(self): | ||||||
|  |         self.maps = { | ||||||
|  |             "networkx": self.available_map_networkx(), | ||||||
|  |             "torch_geometric": self.available_map_torch_geometric(), | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |     def available_map_networkx(self): | ||||||
|  |         functions_list = getmembers(nx, isfunction) | ||||||
|  | 
 | ||||||
|  |         maps = {} | ||||||
|  |         for func in functions_list: | ||||||
|  |             name, f = func | ||||||
|  |             if name in __maps__: | ||||||
|  |                 maps[name] = f | ||||||
|  |         return maps | ||||||
|  | 
 | ||||||
|  |     def available_map_torch_geometric(self): | ||||||
|  |         names = [ | ||||||
|  |             "num_nodes", | ||||||
|  |             "num_edges", | ||||||
|  |             "has_self_loops", | ||||||
|  |             "has_isolated_nodes", | ||||||
|  |             "num_nodes_features", | ||||||
|  |             "y", | ||||||
|  |         ] | ||||||
|  |         maps = { | ||||||
|  |             name: lambda x, name=name: x.__getattr__(name) if hasattr(x, name) else None | ||||||
|  |             for name in names | ||||||
|  |         } | ||||||
|  |         return maps | ||||||
|  | 
 | ||||||
|  |     def __call__(self, data): | ||||||
|  |         data_ = data.__copy__() | ||||||
|  |         datahash = hash(data.__repr__) | ||||||
|  |         stats = {} | ||||||
|  |         for k, v in self.maps.items(): | ||||||
|  |             if k == "networkx": | ||||||
|  |                 _data_ = to_networkx(data) | ||||||
|  |                 _data_ = _data_.to_undirected() | ||||||
|  |             elif k == "torch_geometric": | ||||||
|  |                 _data_ = data.__copy__() | ||||||
|  |             for name, func in v.items(): | ||||||
|  |                 try: | ||||||
|  |                     val = func(_data_) | ||||||
|  |                 except: | ||||||
|  |                     val = None | ||||||
|  |                 if callable(val): | ||||||
|  |                     val = val() | ||||||
|  |                 if isinstance(val, types.GeneratorType): | ||||||
|  |                     try: | ||||||
|  |                         val = dict(val) | ||||||
|  |                     except: | ||||||
|  |                         val = None | ||||||
|  |                 if name == "hits": | ||||||
|  |                     val = val[0] | ||||||
|  |                 if name == "k_components": | ||||||
|  |                     for key, value in val.items(): | ||||||
|  |                         val[key] = list(val[key][0]) | ||||||
|  | 
 | ||||||
|  |                 stats[name] = self.convert(val) | ||||||
|  |         stats["hash"] = datahash | ||||||
|  |         return stats | ||||||
|  | 
 | ||||||
|  |     def convert(self, val, K=4): | ||||||
|  | 
 | ||||||
|  |         if type(val) == set: | ||||||
|  |             val = list(val) | ||||||
|  |             return val | ||||||
|  | 
 | ||||||
|  |         elif isinstance(val, torch.Tensor): | ||||||
|  |             val = val.cpu().numpy().tolist() | ||||||
|  |             return val | ||||||
|  | 
 | ||||||
|  |         elif isinstance(val, list): | ||||||
|  |             for ind in range(len(val)): | ||||||
|  |                 item = val[ind] | ||||||
|  |                 if isinstance(item, list): | ||||||
|  |                     for subind in range(len(item)): | ||||||
|  |                         subitem = item[subind] | ||||||
|  |                         if type(subitem) == float: | ||||||
|  |                             item[subind] = round(subitem, K) | ||||||
|  | 
 | ||||||
|  |                 else: | ||||||
|  |                     if type(item) == float: | ||||||
|  |                         val[ind] = round(item, K) | ||||||
|  |             return val | ||||||
|  | 
 | ||||||
|  |         elif isinstance(val, dict): | ||||||
|  |             for k, v in val.items(): | ||||||
|  |                 if isinstance(v, dict): | ||||||
|  |                     for k1, v1 in v.items(): | ||||||
|  |                         if type(v1) == float: | ||||||
|  |                             v[k1] = round(v1, K) | ||||||
|  |                         if type(v1) == set: | ||||||
|  |                             v[k1] = list(v1) | ||||||
|  | 
 | ||||||
|  |                 else: | ||||||
|  |                     if type(v) == float: | ||||||
|  |                         val[k] = round(v, K) | ||||||
|  |             return val | ||||||
|  | 
 | ||||||
|  |         elif isinstance(val, nx.classes.reportviews.DegreeView): | ||||||
|  |             val = np.array(val).tolist() | ||||||
|  |             return val | ||||||
|  | 
 | ||||||
|  |         elif isinstance(val, nx.Graph): | ||||||
|  |             val = nx.node_link_data(val) | ||||||
|  |             return val | ||||||
|  | 
 | ||||||
|  |         elif isinstance(val, np.ndarray): | ||||||
|  |             val = np.around(val, decimals=K) | ||||||
|  |             return val.tolist() | ||||||
|  | 
 | ||||||
|  |         else: | ||||||
|  |             return val | ||||||
							
								
								
									
										0
									
								
								tool/__init__.py
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										0
									
								
								tool/__init__.py
									
										
									
									
									
										Normal file
									
								
							
							
								
								
									
										0
									
								
								tool/explainer/__init__.py
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										0
									
								
								tool/explainer/__init__.py
									
										
									
									
									
										Normal file
									
								
							
							
								
								
									
										3
									
								
								tool/explainer/base.py
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										3
									
								
								tool/explainer/base.py
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,3 @@ | ||||||
|  | class BaseExplaining(object): | ||||||
|  |     def __init__(self,model,explainer_name:wq | ||||||
|  | 
 | ||||||
							
								
								
									
										1
									
								
								tool/explainer/from_captum.py
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										1
									
								
								tool/explainer/from_captum.py
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1 @@ | ||||||
|  | from torch_geometric.nn.models.captum import CaptumModel | ||||||
							
								
								
									
										0
									
								
								tool/graphgym/__init__.py
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										0
									
								
								tool/graphgym/__init__.py
									
										
									
									
									
										Normal file
									
								
							
							
								
								
									
										0
									
								
								tool/visualizer/__init__.py
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										0
									
								
								tool/visualizer/__init__.py
									
										
									
									
									
										Normal file
									
								
							
							
								
								
									
										125
									
								
								utils/stat.py
									
										
									
									
									
								
							
							
						
						
									
										125
									
								
								utils/stat.py
									
										
									
									
									
								
							|  | @ -1,125 +0,0 @@ | ||||||
| #!/usr/bin/env python |  | ||||||
| # -*- coding: utf-8 -*- |  | ||||||
| 
 |  | ||||||
| import logging |  | ||||||
| import multiprocessing as mp |  | ||||||
| import os |  | ||||||
| import threading |  | ||||||
| import time |  | ||||||
| import types |  | ||||||
| from inspect import getmembers, isfunction, signature |  | ||||||
| 
 |  | ||||||
| import networkx as nx |  | ||||||
| from torch_geometric.data import Data |  | ||||||
| from torch_geometric.utils import to_networkx |  | ||||||
| 
 |  | ||||||
| with open("mapping_nx.txt", "r") as file: |  | ||||||
|     BLACK_LIST = [line.rstrip() for line in file] |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| class GraphStat(object): |  | ||||||
|     def __init__(self): |  | ||||||
|         self.maps = { |  | ||||||
|             "networkx": self.available_map_networkx(), |  | ||||||
|             "torch_geometric": self.available_map_torch_geometric(), |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|     def available_map_networkx(self): |  | ||||||
|         functions_list = getmembers(nx.algorithms, isfunction) |  | ||||||
|         MANUALLY_ADDED = [ |  | ||||||
|             "algebraic_connectivity", |  | ||||||
|             "adjacency_spectrum", |  | ||||||
|             "degree", |  | ||||||
|             "density", |  | ||||||
|             "laplacian_spectrum", |  | ||||||
|             "normalized_laplacian_spectrum", |  | ||||||
|             "number_of_selfloops", |  | ||||||
|             "number_of_edges", |  | ||||||
|             "number_of_nodes", |  | ||||||
|         ] |  | ||||||
|         MANUALLY_ADDED_LIST = [ |  | ||||||
|             item for item in getmembers(nx, isfunction) if item[0] in MANUALLY_ADDED |  | ||||||
|         ] |  | ||||||
|         functions_list = functions_list + MANUALLY_ADDED_LIST |  | ||||||
|         maps = {} |  | ||||||
|         for func in functions_list: |  | ||||||
|             name, f = func |  | ||||||
|             if ( |  | ||||||
|                 name in BLACK_LIST |  | ||||||
|                 or name == "recursive_simple_cycles" |  | ||||||
|                 or "triad" in name |  | ||||||
|                 or "weisfeiler" in name |  | ||||||
|                 or "dfs" in name |  | ||||||
|                 or "trophic" in name |  | ||||||
|                 or "recursive" in name |  | ||||||
|                 or "scipy" in name |  | ||||||
|                 or "numpy" in name |  | ||||||
|                 or "sigma" in name |  | ||||||
|                 or "omega" in name |  | ||||||
|                 or "all_" in name |  | ||||||
|             ): |  | ||||||
|                 continue |  | ||||||
|             else: |  | ||||||
|                 maps[name] = f |  | ||||||
|         return maps |  | ||||||
| 
 |  | ||||||
|     def available_map_torch_geometric(self): |  | ||||||
|         names = [ |  | ||||||
|             "num_nodes", |  | ||||||
|             "num_edges", |  | ||||||
|             "has_self_loops", |  | ||||||
|             "has_isolated_nodes", |  | ||||||
|             "num_nodes_features", |  | ||||||
|             "y", |  | ||||||
|         ] |  | ||||||
|         maps = { |  | ||||||
|             name: lambda x, name=name: x.__getattr__(name) if hasattr(x, name) else None |  | ||||||
|             for name in names |  | ||||||
|         } |  | ||||||
|         return maps |  | ||||||
| 
 |  | ||||||
|     def __call__(self, data): |  | ||||||
|         data_ = data.__copy__() |  | ||||||
|         stats = {} |  | ||||||
|         for k, v in self.maps.items(): |  | ||||||
|             if k == "networkx": |  | ||||||
|                 _data_ = to_networkx(data) |  | ||||||
|                 _data_ = _data_.to_undirected() |  | ||||||
|             elif k == "torch_geometric": |  | ||||||
|                 _data_ = data.__copy__() |  | ||||||
|             for name, f in v.items(): |  | ||||||
|                 if f is None: |  | ||||||
|                     stats[name] = None |  | ||||||
|                     continue |  | ||||||
|                 else: |  | ||||||
|                     try: |  | ||||||
|                         t0 = time.time() |  | ||||||
|                         val = f(_data_) |  | ||||||
|                         t1 = time.time() |  | ||||||
|                         delta = t1 - t0 |  | ||||||
|                     except Exception as e: |  | ||||||
|                         print(name, e) |  | ||||||
|                         with open(f"{name}.txt", "w") as f: |  | ||||||
|                             f.write(str(e)) |  | ||||||
|                         # print(name, round(delta, 4)) |  | ||||||
|                         # if callable(val) and k == "torch_geometric": |  | ||||||
|                         # val = val() |  | ||||||
|                         # if isinstance(val, types.GeneratorType): |  | ||||||
|                         # val = list(val) |  | ||||||
|                     # stats[name] = val |  | ||||||
|         return stats |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| from torch_geometric.datasets import KarateClub, Planetoid |  | ||||||
| 
 |  | ||||||
| d = Planetoid(root="/tmp/", name="Cora") |  | ||||||
| # d = KarateClub() |  | ||||||
| a = d[0] |  | ||||||
| st = GraphStat() |  | ||||||
| stat = st(a) |  | ||||||
| for k, v in stat.items(): |  | ||||||
|     print("---------") |  | ||||||
|     print("Name:", k) |  | ||||||
|     print("Type:", type(v)) |  | ||||||
|     print("Val:", v) |  | ||||||
|     print("---------") |  | ||||||
		Loading…
	
	Add table
		
		Reference in a new issue
	
	 araison
						araison