Coverage for NeuralTSNE/NeuralTSNE/Utils/Preprocessing/Filters/filters.py: 83%
10 statements
« prev ^ index » next coverage.py v7.8.0, created at 2025-05-18 16:32 +0000
« prev ^ index » next coverage.py v7.8.0, created at 2025-05-18 16:32 +0000
1from typing import Union
3import numpy as np
4import torch
def filter_data_by_variance(
    data: torch.Tensor, variance_threshold: Union[float, None]
) -> Union[torch.Tensor, None]:
    """
    Filter columns of a 2D `torch.Tensor` based on the variance of each column.

    If the `variance_threshold` is `None`, the function returns `None`, indicating no filtering is performed.

    Parameters
    ----------
    `data` : `torch.Tensor`
        The input 2D tensor with columns to be filtered.
    `variance_threshold` : `float` | `None`
        The threshold for column variance. Only columns whose variance is strictly
        greater than this threshold are kept; columns with variance less than or
        equal to it are filtered out. Pass `None` to skip filtering.

    Returns
    -------
    `torch.Tensor` | `None`
        If `variance_threshold` is `None`, returns `None`. Otherwise, returns a new
        `tensor` containing only the retained columns, in their original order and
        on the same device as `data`.

    Note
    ----
    - If `variance_threshold` is set to `None`, the function returns `None`, and no filtering is performed.
    - Column variance is computed with `torch.Tensor.var` using its default
      (unbiased, Bessel-corrected) estimator. With a single row the variance is
      NaN, and NaN never compares greater than the threshold, so every column is
      removed in that case.
    - Column selection uses a boolean mask computed in `torch`, so it works for
      tensors on any device (no round-trip through NumPy).
    """
    if variance_threshold is None:
        return None
    column_vars = data.var(dim=0)
    # Keep the mask as a torch tensor on data's device. np.where would convert it
    # to a NumPy array, which fails for non-CPU (e.g. CUDA) tensors.
    keep_mask = column_vars > variance_threshold
    filtered_data = data[:, keep_mask]
    return filtered_data