explaining_framework/explaining_framework/utils/explanation/adjust.py
2022-12-20 16:09:03 +01:00

16 lines
488 B
Python

import copy
from torch import FloatTensor
from torch.nn import ReLU
def relu_mask(explanation: "Explanation") -> "Explanation":
    """Clamp every mask stored on *explanation* to be non-negative.

    Every entry of ``explanation._store`` whose key contains the substring
    ``"mask"`` is replaced by ``ReLU(value)``; all other entries are left as
    they are. A shallow copy of the store is taken before any entry is
    replaced. That copy is then attached as ``explanation.raw_explanation``,
    together with ``explanation.raw_explanation_transform = "relu"``, so the
    untransformed values stay reachable.

    The object is modified in place and also returned for convenience.

    Args:
        explanation: The explanation to adjust. Presumably a
            ``torch_geometric.explain.Explanation`` (TODO confirm). The only
            requirements visible here are:

            * a mapping-like ``_store`` attribute whose ``"mask"`` entries are
              tensors accepted by ``torch.nn.ReLU``;
            * support for attribute assignment.

    Returns:
        The same ``explanation`` object, after mutation.
    """
    # The annotations are quoted because ``Explanation`` is not imported in
    # this module; unquoted, they raised NameError when the module was loaded.
    relu = ReLU()
    store = explanation._store
    # Shallow snapshot taken before any mask is modified. ReLU() is not
    # in-place (inplace=False by default), so the tensors referenced by the
    # snapshot keep their raw values even after the store is updated.
    raw_data = copy.copy(store)
    # Iterate the snapshot rather than the live store, so writing results back
    # into ``store`` can never disturb the iteration.
    for key, value in raw_data.items():
        if "mask" in key:
            store[key] = relu(value)
    explanation.raw_explanation = raw_data
    explanation.raw_explanation_transform = "relu"
    return explanation