From b05e44ba795715587442a7f8f9bdbbd2582fc534 Mon Sep 17 00:00:00 2001 From: araison Date: Wed, 8 Mar 2023 18:59:37 +0100 Subject: [PATCH] Upload to github --- README.md | 41 ++++ scgnn/__init__.py | 0 scgnn/__pycache__/__init__.cpython-310.pyc | Bin 0 -> 142 bytes scgnn/__pycache__/scgnn.cpython-310.pyc | Bin 0 -> 5548 bytes scgnn/scgnn.py | 213 ++++++++++++++++++ .../__pycache__/embedding.cpython-310.pyc | Bin 0 -> 1786 bytes scgnn/utils/embedding.py | 54 +++++ setup.py | 9 + 8 files changed, 317 insertions(+) create mode 100644 README.md create mode 100644 scgnn/__init__.py create mode 100644 scgnn/__pycache__/__init__.cpython-310.pyc create mode 100644 scgnn/__pycache__/scgnn.cpython-310.pyc create mode 100644 scgnn/scgnn.py create mode 100644 scgnn/utils/__pycache__/embedding.cpython-310.pyc create mode 100644 scgnn/utils/embedding.py create mode 100644 setup.py diff --git a/README.md b/README.md new file mode 100644 index 0000000..a3242e6 --- /dev/null +++ b/README.md @@ -0,0 +1,41 @@ +Here is example code for using ScoreCAM GNN, from the [ScoreCAM GNN : a generalization of an optimal local post-hoc explaining method to any geometric deep learning models](https://arxiv.org/abs/2207.12748) paper: + +```python +from torch_geometric.datasets import TUDataset + + dataset = TUDataset(root="/tmp/ENZYMES", name="ENZYMES") + data = dataset[0] + from scgnn.scgnn import SCGNN + + import torch.nn.functional as F + from torch_geometric.nn import GCNConv, global_mean_pool + + + model = Sequential( + "data", + [ + ( + lambda data: (data.x, data.edge_index, data.batch), + "data -> x, edge_index, batch", + ), + (GCNConv(dataset.num_node_features, 64), "x, edge_index -> x"), + (GCNConv(64, dataset.num_classes), "x, edge_index -> x"), + (global_mean_pool, "x, batch -> x"), + ], + ) + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + model = model.to(device) + data = dataset[0].to(device) + optimizer = torch.optim.Adam(model.parameters(), lr=0.01, 
weight_decay=5e-4) + model.eval() + out = model(data) + explainer = SCGNN() + explained = explainer.forward( + model, + data.x, + data.edge_index, + target=2, + interest_map_norm=True, + score_map_norm=True, + ) +``` diff --git a/scgnn/__init__.py b/scgnn/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/scgnn/__pycache__/__init__.cpython-310.pyc b/scgnn/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..316468dca963208078e057c214e0d9a80940cbfb GIT binary patch literal 142 zcmd1j<>g`kf(v}BlR@-j5P=LBfgA@QE@lA|DGb33nv8xc8Hzx{2;!Haenx(7s(!Gi zvwmVxVrFrEo_=aYL9u>ua(Z4~d{GjFp&uWgnU`4-AFo$Xd5gm)H$SB`C)EyQR525f HU||3N^0prR literal 0 HcmV?d00001 diff --git a/scgnn/__pycache__/scgnn.cpython-310.pyc b/scgnn/__pycache__/scgnn.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..98e983387b2165c5d968521652457ce36b61becd GIT binary patch literal 5548 zcmZ`-O>i8?b)M;&ot+)*E*2n2kP<~LC7IG%(;;YM5z>$|+b4=x!Yfh=ITv9nPhg9W|D$9|?e6I&s0E{$S z{pP*y*FV$!-uGT_+GuzNp8x!Zf9?JAIm7r55{^GEgc~T?f0>5C4Q@t;Pj42PG4ok$ z`BrTEwpq4Sd`Ih4{i=Suag7#T=(mmx1&bf^qcXlKO49FmaeNrbMd@C zAD{A1Y1xSu;?w?V%~zu{@uI(|c{e&6pYzYfFZnOUFZ(aY=l%2XEB-6GtrlH~U-e&& z+kRWiUUV^j&3`R^-G4oP!+!&HGrS&sBfjKcGL5?iZ}8?LgEx1W|0bka-h$K;i#wM8 zmgw-g`%L~1zGXi8X8gAQwrPmB`6<4@Pwz7SvbfBRbB2HAE5G#&*3tF5@shuUUTwa} z&+>B^`=ve;U2%Ez4u5%v@$)3{mm$A$i+eQWRA`-c1{ejzdW ztNW(^p18XATjD)%SIZR~NLQfy+oSKG^}G89Z}W@%wOz~q9^~KI$0>-b;>wQcUlVWb z82{n2%*8GV3qgTKz-!0#Koj(?MzD>&yBqjTvQS;BJ1Qtl7O`7lkwNL5xR<4CB= z#|h+)shm}jWT}LXb1&r+GPp(r<8ZUXROS0&-dk7p?Jy6~Z|2VCI0}=H zhJ`q<#bF}k&1jIyVZI)BDr$CGdn--)!+~<|h1vbphhw2;FpG!|s%u&xqMDjrh1y4A zLqwRkc~4|nI1qmjX4x;Ij{0%Jh!a(FfT#guw%OL|1uL;gEc-EsGJN_^TH&C)SK!o9&-0&GUIeTRI7Pq*K zx6Lcu!Mnn%*mYOgKS7Y(aF~G(P z8T3>6X(%~9C%vz#M1y{!Y+|q}b|#{}a$Rt*5bATKU~uOQ8&hS(+?~ z&2hGr^#(~2$hDFwX?Ofkxj`^YhItSy(z0lc^X68A7T-N=>Dj0~e26D4{S_@R1;gr% z!UDylkp^1PCCi)?%47I>gB6MUME6(sh3=mizJ>0q3#*Y;+USfRo5#y 
z`0^f|-{C|2bRkU$*h*v1Bquqx4EYLLlp{r)ZitRqE~1UPqghp@7p0k?<>~Klnp>Wp zRaY-}9NMEKj71PAF9^!*244??2a_-=TjXmrCaqsynvUL{a{B84&Bhs&jPTesT+_Ac zwxrq~%893bAZ0%Xb7XP8Rbw>^KbxQ76=8Enz}`fq?>sj8CU<%5Q3L4hp@tDcKC*$~ z4dISzBab)544*~uM7{6w7N0|F|?_={1)tbQB_uM-CbVZ z`mxu>v$`(YX}>@0!4KQR7@!ivDK#FTzP-{*rMPwTUi(vEK^xuL{V3c>C$jyMZnyi1 zM-L5iQx3B4m8`O(Fw0j}ZJ6|h6iE*-#l<*ZS98MzE+H~_Za5B-RK}{A(F9YuS2?za z4tO$;8IJz^ielP$|{qEDx)D%oT#WO=+-P$j!W zmO#{O9!f%zwGfU7*i$yyLdPkM?4+`MmwGt&;m(8Xsi|r~(8HHzL2zVcd=GTw>omSS;2C$6{F9feMM{A1^Y_vnQiPIwCmIZrO~(0=dSKkDJ(Buu zpmj)s7@*nONc3{mOfc6`qnhibahxVWnCEi1HpxY%yd(Y){Jw)BS*yHS6YrX`3!<9=sVO`Z-@LRPbhphH=z9aEw_DOzEO_ zj9-JvjVJ7h{lw(vuC?tP@NAcDSD#oTK>Eax3x)mHhL$X(QWuNrh;anGU31UcH*^c2 z2rb`+PNf{7aCr5c@xEUk3M{DZHXy zG>RIZ-F3H{MH5cmLa>#$u>tNw~MHE4VCu0_ZqV3_0UDJnInI0m!rFFezJ?mXYRSc1wy zlF_{urM)o9u7f%3Cn?BoIkyRb-``r$t~!Nz!P^f--f@q1OqmZ=ZR$3vHg%E%f9Z>m z(3?^;t#u^Nip^Yx$3E7RXX!&q)C6WC|5Wxc7qQg1`a?9TN}ft&qgF4>mG>iZ<2&Ti zk}SJ3TcYsE<5f-x+44Q=Uz?_U%0iZ+YSX=ze@v5qOk|bFkBIyj1n0atE**&)7G%ox-7{=uhF!+iuSGtfIPa?h&R}_RZ~D z?iB8hdE5A8W!o#5?lGg?&G;EBygpoRyI$1ys{010yKlgGyAKw@)7T9}jheg-w`!2a zS!2%w2dzUHGf>+MM~(|_rf4j{IT3$n3X2oOH1qnXQ8X!sCaGxdH8E~8pDmh2 zk)*7LxZao_s%lo=oxtF}BDT=6AEmHYTZ(9+(}+&%t1R{ClvVwB;-jhwavx6dB?-uX zCYVizv`X{7VgbB_YL%_s%nrz|Erhrif{i2x`Wuw&9T4Q9<~(AJ3tMf$0$n5?7ht&* zf0_r>V`msi_H*0XI#pWaNqn+3KCZ56-0aNBH!-q&i%17VRfB-1J><<+mP=)?rD;@V zRb;W{KC>i|WW$ou812TAGGk?72lR?n;~;J7i6}Zuo!X@QEh6NNI&Y)knbC2&2E3|X z`jq{@@UJhK#lB#6^Yi8`JL61c_BpG1pRw;cCVLL;lZDNHsYj)IK#UwA2msEt)sN{G zoQZs^zLdw~r8~?1pWM5%qN)cBUC?Bj=JG1m0ik;L)-tZT8>%&k(zP(c1vE^8F}6Y3 z93!2G8t~tvyQ02jA(w;2BSRgrKJvBlmSr8o-pU%?%@IAGpzjOEt=wx-7_ag0`ahyE zqnOk(B;75x&ePoOE7#kbm)l1f!{xRf?b!`dnLp8_b+;O*(<)k2c5lK%An|xI_=|X$Qre+s!Lzf+6@o?4(ZZWK~i?F zE_|JGnZL>gO-ebZq=Ze94U~E8{EHpoH1pJY3PhcLv2~IUqf9Q4!f7IO0adfQwmXe~ zWw522s5W%%y6WtUL(yiy-ZQyFV^j~s(%$}x`aY)mIUJ6Ls3|sWosa=v4CXT)nK>)D6Hdmp0%t6v(B1Ut8V<%t)Y%I r7=Y4aregsvzHHhE? 
bool: + task_level = self.model_config.task_level + if task_level not in [ModelTaskLevel.graph]: + logging.error(f"Task level '{task_level.value}' not supported") + return False + + edge_mask_type = self.explainer_config.edge_mask_type + if edge_mask_type not in [MaskType.object, None]: + logging.error(f"Edge mask type '{edge_mask_type.value}' not " f"supported") + return False + + node_mask_type = self.explainer_config.node_mask_type + if node_mask_type not in [ + MaskType.common_attributes, + MaskType.object, + MaskType.attributes, + ]: + logging.error(f"Node mask type '{node_mask_type.value}' not " f"supported.") + return False + + return True + + def forward( + self, + model: torch.nn.Module, + x: Tensor, + edge_index: Tensor, + target, + **kwargs, + ) -> Explanation: + embedding = get_message_passing_embeddings( + model=model, x=x, edge_index=edge_index + ) + + out = model(x=x, edge_index=edge_index) + + if self.target_baseline is None: + c = target + if self.target_baseline == "inference": + c = out.argmax(dim=1).item() + + if self.depth == "last": + score_map = self.get_score_map( + model=model, x=x, edge_index=edge_index, emb=embedding[-1], c=c + ) + extra_score_map = None + elif self.depth == "all": + score_map = self.get_score_map( + model=model, x=x, edge_index=edge_index, emb=embedding[-1], c=c + ) + extra_score_map = torch.cat( + [ + self.get_score_map( + model=model, x=x, edge_index=edge_index, emb=emb, c=c + ) + for emb in embedding[:-1] + ], + dim=0, + ) + else: + raise ValueError(f"Depth={self.depth} not implemented yet") + + node_mask = score_map + edge_mask = None + node_feat_mask = None + edge_feat_mask = None + + exp = Explanation( + x=x, + edge_index=edge_index, + y=target, + edge_mask=edge_mask, + node_mask=node_mask, + node_feat_mask=node_feat_mask, + edge_feat_mask=edge_feat_mask, + extra_score_map=extra_score_map, + ) + return exp + + def get_score_map( + self, model: torch.nn.Module, x: Tensor, edge_index: Tensor, emb: Tensor, c: int + ) 
-> Tensor: + interest_map = emb.clone() + n_nodes, n_features = interest_map.size() + score_map = torch.zeros(n_nodes).to(x.device) + for k in range(n_features): + _x = x.clone() + feat = interest_map[:, k] + if feat.min() == feat.max(): + continue + mask = feat.clone() + if self.interest_map_norm: + mask = (mask - mask.min()).div(mask.max() - mask.min()) + mask = mask.reshape((-1, 1)) + _x = _x * mask + _out = model(x=_x, edge_index=edge_index) + _out = F.softmax(_out, dim=1) + _out = _out.squeeze() + val = float(_out[c]) + score_map = score_map + val * feat + + score_map = F.relu(score_map) + + if self.score_map_norm and score_map.min() != score_map.max(): + score_map = (score_map - score_map.min()).div( + score_map.max() - score_map.min() + ) + return score_map + + +if __name__ == "__main__": + from torch_geometric.datasets import TUDataset + + dataset = TUDataset(root="/tmp/ENZYMES", name="ENZYMES") + data = dataset[0] + + import torch.nn.functional as F + from torch_geometric.nn import GCNConv, global_mean_pool + + # model = torch.nn.ModuleDict( + # { + # "conv1": GCNConv(dataset.num_node_features, 64), + # "conv2": GCNConv(64, dataset.num_classes), + # "gmp": global_mean_pool, + # } + # ) + + model = Sequential( + "data", + [ + ( + lambda data: (data.x, data.edge_index, data.batch), + "data -> x, edge_index, batch", + ), + (GCNConv(dataset.num_node_features, 64), "x, edge_index -> x"), + (GCNConv(64, dataset.num_classes), "x, edge_index -> x"), + (global_mean_pool, "x, batch -> x"), + ], + ) + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + model = model.to(device) + data = dataset[0].to(device) + optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4) + model.eval() + out = model(data) + explainer = SCGNN() + explained = explainer.forward( + model, + data.x, + data.edge_index, + target=2, + interest_map_norm=True, + score_map_norm=True, + ) diff --git a/scgnn/utils/__pycache__/embedding.cpython-310.pyc 
b/scgnn/utils/__pycache__/embedding.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..49f00cc8120481958af2263a18416cada7e2ed65 GIT binary patch literal 1786 zcmbtUPj4GV6rY)$U2lx@rzx#eRLY!497VP`fGVr1h~UzSMu-+6EFov(o%m1IGw#ed z!POcef<$|P3m?FixbzDk@jW=rl@nh9A(i)bosbp@4vaO=J8$Q`_nY7R-s;uWDuMCi zFF$vmMTGo@i_0T~#rrVreK09WX-3-k&oVkUapt$fSdaY`SOdi*=&Bl~scgFQ9%{G)+&b?T!>Qdw<$IMS)id}O4X7}ce? zc{x_#KZdyu)BXx3rGSQv=nRnb$TPs;*f1FKl)EquU3e6IM~-Mwy*A@LG7EDuj~3)Z z^2Pcra+N&FgEMA!^Kc&L(R?N63%a0`Jbu$v^YDzO^q863?zMR>jnSWz;p(gcIo1}0 z636E23-ZG{`g1xUIs1uz54%1i=?W*DoB=8ozeDV&r{Klk=P)69Bd-+MLckHt;yj*T z!F-RCrv&(^CUr*0C&lrtaD6HC#EmCT$kCxp(}C()p^rq8WmN&A)y)!XTL)h`ZMywV zPwJ6$X3%XY)#z$<(s&@PO?vWU*afi%B1@*ySTRs#xIP@V4hqf>8r8yePdRB6P@>%2xQHz(v;-uajHGh5*e6rrbOE$TeQEa+h}hO+H_$*fRM~^eEf?ZgPER!u z3}F8&oxu)){#kqZo(;CIj|K{qV|}HRknINGJsFUPiNyFH~k=CFKdPTU}~ z&5Ov!7``Sb$+sYC(|Gtd$)nXPu|7b-YPC^ zOQc#_p|lhIU>BCm%Z?2OXZvlO*#Mr8_;70MEq>{DzM7iL9@b zV&Eu!AC)ao05h9dvW68peP}j-;-eCixrwi$(rY#l-x``L=)U#ugyO~8x53n|a(0XH zs79-FgR-w{tV$ym7pt2usw{>zX9l+d$Tu)` List[Tensor]: + """Returns the output embeddings of all + :class:`~torch_geometric.nn.conv.MessagePassing` layers in + :obj:`model`. + + Internally, this method registers forward hooks on all + :class:`~torch_geometric.nn.conv.MessagePassing` layers of a :obj:`model`, + and runs the forward pass of the :obj:`model` by calling + :obj:`model(*args, **kwargs)`. + + Args: + model (torch.nn.Module): The message passing model. + *args: Arguments passed to the model. + **kwargs (optional): Additional keyword arguments passed to the model. 
+ """ + from torch_geometric.nn import MessagePassing + + embeddings: List[Tensor] = [] + + def hook(model: torch.nn.Module, inputs: Any, outputs: Any): + # Clone output in case it will be later modified in-place: + outputs = outputs[0] if isinstance(outputs, tuple) else outputs + assert isinstance(outputs, Tensor) + embeddings.append(outputs.clone()) + + hook_handles = [] + for module in model.modules(): # Register forward hooks: + if isinstance(module, MessagePassing): + hook_handles.append(module.register_forward_hook(hook)) + + if len(hook_handles) == 0: + warnings.warn("The 'model' does not have any 'MessagePassing' layers") + + training = model.training + model.eval() + with torch.no_grad(): + model(*args, **kwargs) + model.train(training) + + for handle in hook_handles: # Remove hooks: + handle.remove() + + return embeddings diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..79bb91f --- /dev/null +++ b/setup.py @@ -0,0 +1,9 @@ +from setuptools import setup + +setup( + name="scgnn", + version="0.1", + description="Official implementation of ScoreCAM GNN for explaining graph neural networks", + packages=["scgnn"], + zip_safe=False, +)