Coverage for NeuralTSNE/NeuralTSNE/Utils/Preprocessing/Filters/filters.py: 83%

10 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-05-18 16:32 +0000

1from typing import Union 

2 

3import numpy as np 

4import torch 

5 

6 

def filter_data_by_variance(
    data: torch.Tensor, variance_threshold: Union[float, None]
) -> Union[torch.Tensor, None]:
    """
    Filter columns of a 2D `torch.Tensor` based on the variance of each column.

    If the `variance_threshold` is `None`, the function returns `None`, indicating no filtering is performed.

    Parameters
    ----------
    `data` : `torch.Tensor`
        The input 2D tensor with columns to be filtered.
    `variance_threshold` : `float` | `None`
        The threshold for column variance. Only columns whose variance is strictly
        greater than this threshold are kept; columns with variance at or below it
        are filtered out. `None` disables filtering (see Returns).

    Returns
    -------
    `torch.Tensor` | `None`
        If `variance_threshold` is `None`, returns `None`. Otherwise, returns a new `tensor`
        containing only the retained columns, in their original order. If no column
        passes the threshold, the result has zero columns.

    Note
    ----
    - Variance is computed with `torch.Tensor.var` along dimension 0 using its default
      (unbiased) estimator, so a single-row input yields `NaN` variances and every
      column is dropped (`NaN > threshold` is false).
    - Column selection uses a boolean mask living on the same device as `data`, so the
      function works for CPU and GPU tensors alike.
    """
    if variance_threshold is None:
        return None

    column_vars = data.var(dim=0)
    # Index with a boolean mask instead of `np.where`: the mask stays on `data`'s
    # device, whereas NumPy cannot read CUDA tensors and would force a host copy.
    keep = column_vars > variance_threshold
    return data[:, keep]