Coverage for NeuralTSNE/NeuralTSNE/Utils/Writers/StatWriters/tests/test_stat_writers.py: 100%
64 statements

import os
import random
import string
from typing import List, Tuple
from unittest.mock import MagicMock, mock_open, patch

import numpy as np
import pytest
import torch

from NeuralTSNE.TSNE.tests.common import (
    DataLoaderMock,
    MyDataset,
    PersistentStringIO,
)
from NeuralTSNE.Utils.Writers.StatWriters import (
    save_means_and_vars,
    save_results,
)
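

# Helper: build a random alphanumeric string, used below to create unique
# temporary output file names.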
def get_random_string(length: int):
    return "".join(random.choices(string.ascii_letters + string.digits, k=length))
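

# Fixture: yields a random output file name together with an optional suffix
# and removes the resulting file (if any) once the test has finished.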
@pytest.fixture(params=[""])
def get_output_filename(request):
    suffix = request.param
    file_name = get_random_string(10)
    yield file_name, suffix
    file_to_delete = f"{file_name}{suffix}"
    if os.path.exists(file_to_delete):
        os.remove(file_to_delete)
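

# `save_means_and_vars` is expected to write a whitespace-separated table of
# per-column means and sample variances, followed by an analogous `filtered_*`
# table when filtered data is provided. The test captures the output through a
# mocked `open` and compares it against NumPy reference values.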
@pytest.mark.parametrize(
    "data, filtered_data",
    [
        (
            [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]],
            [[1.0, 2.0], [4.0, 5.0], [7.0, 8.0]],
        ),
        (
            [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]],
            [[1.0, 3.0], [4.0, 6.0], [7.0, 9.0]],
        ),
        ([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]], None),
    ],
)
@patch("builtins.open", new_callable=mock_open)
def test_save_means_and_vars(
    mock_open: MagicMock,
    data: List[List[float]],
    filtered_data: List[List[float]] | None,
):
    file_handle = PersistentStringIO()
    mock_open.return_value = file_handle

    data_means = np.mean(data, axis=0)
    data_vars = np.var(data, axis=0, ddof=1)

    filtered_data_means = (
        None if filtered_data is None else np.mean(filtered_data, axis=0)
    )
    filtered_data_vars = (
        None if filtered_data is None else np.var(filtered_data, axis=0, ddof=1)
    )

    data = torch.tensor(data)
    if filtered_data is not None:
        filtered_data = torch.tensor(filtered_data)

    save_means_and_vars(data, filtered_data)
    lines = file_handle.getvalue().splitlines()

    assert lines[0].split() == ["column", "mean", "var"]
    for i, (mean, var) in enumerate(zip(data_means, data_vars)):
        assert lines[i + 1].split() == [f"{i}", f"{mean}", f"{var}"]
    if filtered_data is not None:
        assert lines[len(data_means) + 2].split() == [
            "filtered_column",
            "filtered_mean",
            "filtered_var",
        ]
        for i, (filtered_mean, filtered_var) in enumerate(
            zip(filtered_data_means, filtered_data_vars)
        ):
            assert lines[i + len(data_means) + 3].split() == [
                f"{i}",
                f"{filtered_mean}",
                f"{filtered_var}",
            ]
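

# `save_results` is expected to write the step value on the first line and one
# tab-separated row per embedded point afterwards, but only when input data is
# supplied; with no data, the output file should not be created at all.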
@pytest.mark.parametrize("batch_shape", [(5, 2), None])
@pytest.mark.parametrize("step", [30, 45, 2])
def test_save_results(
    get_output_filename: Tuple[str, str],
    batch_shape: Tuple[int, int] | None,
    step: int,
):
    TQDM_DISABLE = 1
    filename, _ = get_output_filename
    args = {"o": filename, "step": step}

    test_data = None

    if batch_shape:
        num_samples = batch_shape[0] * 10
        dataset = MyDataset(num_samples, batch_shape[1])
        test_data = DataLoaderMock(dataset, batch_size=2)

    entries_num = random.randint(20, 500)
    Y = [
        [(random.random(), random.random()) for _ in range(entries_num)]
        for _ in range(2)
    ]

    save_results(args, test_data, Y)

    if batch_shape is None:
        assert os.path.exists(filename) is False
        return

    with open(filename, "r") as f:
        lines = f.readlines()

    assert lines[0].strip() == str(step)
    expected_lines = [
        "\t".join(tuple(map(str, item))) for batch in Y for item in batch
    ]
    for i in range(1, len(lines)):
        assert lines[i].strip() == expected_lines[i - 1]