This commit is contained in:
araison 2022-12-17 18:11:01 +01:00
parent eca200fe88
commit ae43fd94f7

View file

@ -77,7 +77,8 @@ for epoch in range(1, 2):
optimizer.step() optimizer.step()
target = torch.LongTensor([[0]]) target = torch.LongTensor([[0]])
for kind in ["node"]: for kind in ["node", "graph"]:
print(kind)
for name in __all__captum + __all__graphxai: for name in __all__captum + __all__graphxai:
if name in __all__captum: if name in __all__captum:
explaining_algorithm = CaptumWrapper(name) explaining_algorithm = CaptumWrapper(name)
@ -97,7 +98,7 @@ for kind in ["node"]:
edge_mask_type="object", edge_mask_type="object",
), ),
model_config=dict( model_config=dict(
mode="classification", mode="regression",
task_level=kind, task_level=kind,
return_type="raw", return_type="raw",
), ),
@ -108,7 +109,10 @@ for kind in ["node"]:
index=int(target), index=int(target),
target=batch.y, target=batch.y,
) )
explanation.__setattr__(
"model_prediction", explainer.get_prediction(x, edge_index)
)
print(explanation.__dict__) print(explanation.__dict__)
except Exception as e: except Exception as e:
print(str(e)) # print(str(e))
pass pass