65 lines
1.9 KiB
Python
65 lines
1.9 KiB
Python
import copy

import torch
from torch import FloatTensor
from torch.nn import ReLU
|
|
|
|
class Adjust(object):
    """Post-process the mask entries of an explanation object.

    Every entry whose key contains ``"mask"`` is passed through a pipeline
    of optional steps, applied in this order:

    1. ``relu`` (negative values -> 0) or ``absolute`` (``|x|``);
       ``__init__`` turns ``apply_relu`` off when ``apply_absolute`` is set,
       so only one of the two runs.
    2. ``project``: sum over dimension 1 for masks with ``ndim >= 2``.
       Keys containing ``"edge"`` are never projected.
    3. ``normalize``: divide by the infinity norm (largest absolute value).

    Args:
        apply_relu: Clamp negative mask values to zero.
        apply_normalize: Scale each mask so its largest magnitude is 1.
        apply_project: Collapse masks with ``ndim >= 2`` by summing dim 1.
        apply_absolute: Replace mask values by their absolute value;
            disables ``apply_relu`` when both are requested.
    """

    def __init__(
        self,
        apply_relu: bool = True,
        apply_normalize: bool = True,
        apply_project: bool = True,
        apply_absolute: bool = False,
    ):
        self.apply_relu = apply_relu
        self.apply_normalize = apply_normalize
        self.apply_project = apply_project
        self.apply_absolute = apply_absolute

        # relu and absolute are alternatives for the sign-handling step;
        # absolute wins when both are requested.
        if self.apply_absolute and self.apply_relu:
            self.apply_relu = False

    def forward(self, exp: "Explanation") -> "Explanation":
        """Return an adjusted copy of ``exp``.

        ``exp`` is copied with ``exp.copy()``; each entry of the copy whose
        key contains ``"mask"`` is replaced (via ``setattr``) by the result
        of the enabled pipeline steps. Mask tensors are replaced, never
        modified in place. Other entries are left as they are.

        Args:
            exp: Object exposing ``copy()`` and ``to_dict()`` (the string
                annotation avoids a NameError when ``Explanation`` is not
                imported in this module).

        Returns:
            The adjusted copy.
        """
        exp_ = exp.copy()
        # NOTE(review): to_dict() is presumably a detached dict rather than a
        # live view, so the results are written back onto ``exp_`` explicitly;
        # assigning into the dict alone would silently drop them.
        for key, mask in exp_.to_dict().items():
            if "mask" not in key:
                continue
            setattr(exp_, key, self._adjust_mask(key, mask))
        return exp_

    def _adjust_mask(self, key: str, mask: FloatTensor) -> FloatTensor:
        """Run every enabled pipeline step, in order, on a single mask."""
        if self.apply_relu:
            mask = self.relu(mask)
        elif self.apply_absolute:
            mask = self.absolute(mask)
        # Edge masks keep their original shape; only other masks are projected.
        if self.apply_project and "edge" not in key:
            mask = self.project(mask)
        if self.apply_normalize:
            mask = self.normalize(mask)
        return mask

    def relu(self, mask: FloatTensor) -> FloatTensor:
        """Return ``mask`` with negative entries replaced by zero."""
        return ReLU()(mask)

    def normalize(self, mask: FloatTensor) -> FloatTensor:
        """Divide ``mask`` by its infinity norm (largest absolute value).

        Empty masks, and masks whose norm is not positive (all zeros or NaN),
        are returned unchanged.
        """
        if mask.numel() == 0:
            # The inf-norm of an empty tensor has no identity and raises.
            return mask
        # The order must be the float infinity; the string "inf" is not a
        # valid vector-norm order for torch.norm.
        norm = torch.norm(mask, p=float("inf")).item()
        if norm > 0:
            return mask / norm
        return mask

    def project(self, mask: FloatTensor) -> FloatTensor:
        """Sum ``mask`` over dimension 1 when it has two or more dimensions.

        NOTE(review): presumably dim 0 indexes nodes and dim 1 features, so
        this yields one score per node -- confirm against the mask producer.
        """
        if mask.ndim >= 2:
            return torch.sum(mask, dim=1)
        return mask

    def absolute(self, mask: FloatTensor) -> FloatTensor:
        """Return the element-wise absolute value of ``mask``."""
        return torch.abs(mask)