Coverage for NeuralTSNE/NeuralTSNE/Utils/Preprocessing/tests/test_preprocessing.py: 100%
29 statements
« prev ^ index » next coverage.py v7.8.0, created at 2025-05-18 16:32 +0000
« prev ^ index » next coverage.py v7.8.0, created at 2025-05-18 16:32 +0000
1from typing import List
2from unittest.mock import MagicMock, patch
4import numpy as np
5import pytest
6import torch
8from NeuralTSNE.Utils.Preprocessing import prepare_data
@pytest.mark.parametrize("variance_threshold", [0.1, 0.5, None])
@pytest.mark.parametrize(
    "data",
    [[[1, 2, 3], [4, 5, 6], [7, 8, 9]], [[9, 3, 3, 1], [1, 4, 2, 6], [3, 5, 11, 9]]],
)
@patch("NeuralTSNE.Utils.Preprocessing.preprocessing.normalize_columns")
@patch("NeuralTSNE.Utils.Preprocessing.preprocessing.save_means_and_vars")
@patch("NeuralTSNE.Utils.Preprocessing.preprocessing.filter_data_by_variance")
def test_prepare_data(
    mock_filter_data_by_variance: MagicMock,
    mock_save_means_and_vars: MagicMock,
    mock_normalize_columns: MagicMock,
    data: List[List[float]],
    variance_threshold: float | None,
):
    """Check that ``prepare_data`` forwards data correctly to its helpers.

    ``filter_data_by_variance``, ``save_means_and_vars`` and
    ``normalize_columns`` are replaced with mocks. The test therefore only
    verifies which arguments ``prepare_data`` passes to them. It also checks
    that ``prepare_data`` returns the (mocked) output of ``normalize_columns``.

    Args:
        mock_filter_data_by_variance: Mock for ``filter_data_by_variance``.
            It returns ``None`` when ``variance_threshold`` is ``None``,
            otherwise the input ``data`` array.
        mock_save_means_and_vars: Mock for ``save_means_and_vars``.
        mock_normalize_columns: Mock for ``normalize_columns``. It returns
            ``data`` converted to a float32 tensor.
        data: Input matrix, supplied by ``parametrize``.
        variance_threshold: Threshold forwarded to the variance filter, or
            ``None``.
    """
    # ``@patch`` decorators are applied bottom-up. The mock for the lowest
    # decorator (filter_data_by_variance) is therefore the first positional
    # argument, ahead of the parametrized ``data``/``variance_threshold``.
    data = np.array(data)
    filtered = None if variance_threshold is None else data
    data_t = torch.tensor(data, dtype=torch.float32)
    mock_filter_data_by_variance.return_value = filtered
    mock_normalize_columns.return_value = data_t

    result = prepare_data(variance_threshold, data)

    mock_filter_data_by_variance.assert_called_once_with(data, variance_threshold)

    mock_normalize_columns.assert_called_once()
    normalize_columns_args = mock_normalize_columns.call_args.args
    assert np.allclose(normalize_columns_args[0], data_t)

    mock_save_means_and_vars.assert_called_once()
    save_means_and_vars_args = mock_save_means_and_vars.call_args.args
    # ``np.allclose`` only returns a bool. Without ``assert`` the comparison
    # would be evaluated and discarded, leaving this argument unchecked.
    assert np.allclose(save_means_and_vars_args[0], data)
    if variance_threshold is None:
        assert save_means_and_vars_args[1] is None
    else:
        assert np.allclose(save_means_and_vars_args[1], filtered)

    assert torch.allclose(result, data_t)