Aborting
This commit is contained in:
parent
eca200fe88
commit
ae43fd94f7
1 changed files with 7 additions and 3 deletions
|
@ -77,7 +77,8 @@ for epoch in range(1, 2):
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
target = torch.LongTensor([[0]])
|
target = torch.LongTensor([[0]])
|
||||||
|
|
||||||
for kind in ["node"]:
|
for kind in ["node", "graph"]:
|
||||||
|
print(kind)
|
||||||
for name in __all__captum + __all__graphxai:
|
for name in __all__captum + __all__graphxai:
|
||||||
if name in __all__captum:
|
if name in __all__captum:
|
||||||
explaining_algorithm = CaptumWrapper(name)
|
explaining_algorithm = CaptumWrapper(name)
|
||||||
|
@ -97,7 +98,7 @@ for kind in ["node"]:
|
||||||
edge_mask_type="object",
|
edge_mask_type="object",
|
||||||
),
|
),
|
||||||
model_config=dict(
|
model_config=dict(
|
||||||
mode="classification",
|
mode="regression",
|
||||||
task_level=kind,
|
task_level=kind,
|
||||||
return_type="raw",
|
return_type="raw",
|
||||||
),
|
),
|
||||||
|
@ -108,7 +109,10 @@ for kind in ["node"]:
|
||||||
index=int(target),
|
index=int(target),
|
||||||
target=batch.y,
|
target=batch.y,
|
||||||
)
|
)
|
||||||
|
explanation.__setattr__(
|
||||||
|
"model_prediction", explainer.get_prediction(x, edge_index)
|
||||||
|
)
|
||||||
print(explanation.__dict__)
|
print(explanation.__dict__)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(str(e))
|
# print(str(e))
|
||||||
pass
|
pass
|
||||||
|
|
Loading…
Add table
Reference in a new issue