Fixing bugs and adding new features
This commit is contained in:
		
							parent
							
								
									1ab84c807c
								
							
						
					
					
						commit
						95a30f2c23
					
				
					 1 changed files with 11 additions and 11 deletions
				
			
		|  | @ -1,24 +1,24 @@ | ||||||
| import torch | import torch | ||||||
|  | 
 | ||||||
| from explaining_framework.metric.base import Metric | from explaining_framework.metric.base import Metric | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class Sparsity(Metric): | class Sparsity(Metric): | ||||||
|     def __init__(self, name): |     def __init__(self, name): | ||||||
|         super().__init__(name=name) |         super().__init__(name=name) | ||||||
|         self.authorized_metric = ['l0'] |         self.authorized_metric = ["l0"] | ||||||
|         self.metric = self.load_metric(name) |         self.metric = self.load_metric(name) | ||||||
| 
 | 
 | ||||||
|     def load_metric(self,name): |     def load_metric(self, name): | ||||||
|         if name in self.authorized_metric: |         if name in self.authorized_metric: | ||||||
|             if name == 'l0': |             if name == "l0": | ||||||
|                 metric = lambda x : torch.mean(mask.float()).item() |                 metric = lambda x: torch.mean(mask.float()).item() | ||||||
|         else: |         else: | ||||||
|             raise ValueError(f'{name} is not supported yet') |             raise ValueError(f"{name} is not supported yet") | ||||||
| 
 | 
 | ||||||
| 
 |     def forward(self, exp: Explanation) -> float: | ||||||
|     def forward(self, exp:Explanation) -> float: |  | ||||||
|         out = {} |         out = {} | ||||||
|         for k,v in exp.to_dict(): |         for k, v in exp.to_dict(): | ||||||
|             if 'mask' in  |             if "mask" in k and v.dtype == torch.bool: | ||||||
| 
 |                 out[k] = torch.mean(mask.float()).item() | ||||||
|         return torch.mean(mask.float()).item() |         return out | ||||||
|  |  | ||||||
		Loading…
	
	Add table
		
		Reference in a new issue
	
	 araison
						araison