Aborting
This commit is contained in:
parent
eca200fe88
commit
ae43fd94f7
|
@ -77,7 +77,8 @@ for epoch in range(1, 2):
|
|||
optimizer.step()
|
||||
target = torch.LongTensor([[0]])
|
||||
|
||||
for kind in ["node"]:
|
||||
for kind in ["node", "graph"]:
|
||||
print(kind)
|
||||
for name in __all__captum + __all__graphxai:
|
||||
if name in __all__captum:
|
||||
explaining_algorithm = CaptumWrapper(name)
|
||||
|
@ -97,7 +98,7 @@ for kind in ["node"]:
|
|||
edge_mask_type="object",
|
||||
),
|
||||
model_config=dict(
|
||||
mode="classification",
|
||||
mode="regression",
|
||||
task_level=kind,
|
||||
return_type="raw",
|
||||
),
|
||||
|
@ -108,7 +109,10 @@ for kind in ["node"]:
|
|||
index=int(target),
|
||||
target=batch.y,
|
||||
)
|
||||
explanation.__setattr__(
|
||||
"model_prediction", explainer.get_prediction(x, edge_index)
|
||||
)
|
||||
print(explanation.__dict__)
|
||||
except Exception as e:
|
||||
print(str(e))
|
||||
# print(str(e))
|
||||
pass
|
||||
|
|
Loading…
Reference in New Issue