This commit is contained in:
araison 2022-12-17 18:11:01 +01:00
parent eca200fe88
commit ae43fd94f7

View File

@ -77,7 +77,8 @@ for epoch in range(1, 2):
optimizer.step()
target = torch.LongTensor([[0]])
for kind in ["node"]:
for kind in ["node", "graph"]:
print(kind)
for name in __all__captum + __all__graphxai:
if name in __all__captum:
explaining_algorithm = CaptumWrapper(name)
@ -97,7 +98,7 @@ for kind in ["node"]:
edge_mask_type="object",
),
model_config=dict(
mode="classification",
mode="regression",
task_level=kind,
return_type="raw",
),
@ -108,7 +109,10 @@ for kind in ["node"]:
index=int(target),
target=batch.y,
)
explanation.__setattr__(
"model_prediction", explainer.get_prediction(x, edge_index)
)
print(explanation.__dict__)
except Exception as e:
print(str(e))
# print(str(e))
pass