import copy

from torch.nn import ReLU

# `Explanation` was undefined in the original snippet; PyTorch Geometric's
# `Explanation` (a `Data` subclass that exposes a `_store` mapping) is an
# assumed match for the usage below.
from torch_geometric.explain import Explanation

def relu_mask(explanation: Explanation) -> Explanation:
    """Clamp every mask in the explanation to non-negative values.

    The untouched store is kept on the explanation as ``raw_explanation``
    so the original masks stay recoverable.
    """
    relu = ReLU()
    explanation_store = explanation._store

    # Shallow-copy the store before mutating it.
    raw_data = copy.copy(explanation._store)

    # Rectify every mask tensor (e.g. ``node_mask``, ``edge_mask``) in place.
    for k, v in explanation_store.items():
        if "mask" in k:
            explanation_store[k] = relu(v)

    # Record the pre-transform store and the transform that was applied.
    explanation.raw_explanation = raw_data
    explanation.raw_explanation_transform = "relu"

    return explanation
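

# Minimal usage sketch; the `node_mask` values are illustrative only,
# not taken from a real explainer run.
if __name__ == "__main__":
    import torch

    explanation = Explanation(node_mask=torch.tensor([-0.5, 0.0, 1.2]))
    explanation = relu_mask(explanation)

    print(explanation.node_mask)                  # tensor([0.0000, 0.0000, 1.2000])
    print(explanation.raw_explanation_transform)  # relu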