explaining_framework/explaining_framework/utils/explanation/adjust.py

import copy

import torch
from torch import FloatTensor
from torch.nn import ReLU
from torch_geometric.explain import Explanation

class Adjust(object):
    """Post-processes the masks of an Explanation: optionally applies a
    ReLU or absolute value, projects node masks over the feature
    dimension, and normalizes by the infinity norm."""

    def __init__(
        self,
        apply_relu: bool = True,
        apply_normalize: bool = True,
        apply_project: bool = True,
        apply_absolute: bool = False,
    ):
        self.apply_relu = apply_relu
        self.apply_normalize = apply_normalize
        self.apply_project = apply_project
        self.apply_absolute = apply_absolute
        # ReLU and absolute value are redundant together: if both are
        # requested, the absolute value takes precedence.
        if self.apply_absolute and self.apply_relu:
            self.apply_relu = False

    def forward(self, exp: Explanation) -> Explanation:
        exp_ = copy.copy(exp)
        # Adjust every stored attribute whose name contains "mask"
        # (e.g. node_mask, edge_mask); everything else passes through.
        for k, v in exp_.to_dict().items():
            if "mask" not in k:
                continue
            # The enabled transforms compose in sequence: (a) ReLU or
            # absolute value, (b) projection, (c) normalization.
            if self.apply_relu:
                v = self.relu(v)
            elif self.apply_absolute:
                v = self.absolute(v)
            if self.apply_project and "edge" not in k:
                # Edge masks are already one-dimensional, so only
                # node-level masks are projected.
                v = self.project(v)
            if self.apply_normalize:
                v = self.normalize(v)
            # Write back explicitly: to_dict() returns a copy of the
            # underlying storage, so mutating the dict has no effect.
            exp_[k] = v
        return exp_

    def relu(self, mask: FloatTensor) -> FloatTensor:
        # Clamp negative attribution scores to zero.
        return ReLU()(mask)

    def normalize(self, mask: FloatTensor) -> FloatTensor:
        # Rescale by the maximum absolute value (infinity norm) so the
        # mask lies in [-1, 1]; all-zero masks are returned unchanged.
        norm = torch.norm(mask, p=float("inf"))
        if norm.item() > 0:
            return mask / norm.item()
        return mask

    def project(self, mask: FloatTensor) -> FloatTensor:
        # Collapse per-feature scores into one score per node by
        # summing over the feature dimension; 1-D masks are unchanged.
        if mask.ndim >= 2:
            return torch.sum(mask, dim=1)
        return mask

    def absolute(self, mask: FloatTensor) -> FloatTensor:
        # Keep only the magnitude of each attribution score.
        return torch.abs(mask)
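

# A minimal usage sketch, assuming torch_geometric >= 2.2 (for the
# Explanation class); the mask values below are made up for illustration.
if __name__ == "__main__":
    exp = Explanation(
        node_mask=torch.tensor([[-0.5, 1.0], [0.2, 0.4]]),  # 2 nodes, 2 features
        edge_mask=torch.tensor([0.9, -0.3]),
    )
    adjusted = Adjust(apply_relu=True, apply_project=True, apply_normalize=True).forward(exp)
    print(adjusted.node_mask)  # tensor([1.0000, 0.6000]): ReLU, sum over features, max-normalize
    print(adjusted.edge_mask)  # tensor([1.0000, 0.0000]): ReLU and max-normalize only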