Coverage for NeuralTSNE/NeuralTSNE/Utils/Writers/StatWriters/tests/test_stat_writers.py: 100%

64 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-05-18 16:32 +0000

1import os 

2import random 

3import string 

4from typing import List, Tuple 

5from unittest.mock import MagicMock, mock_open, patch 

6 

7import numpy as np 

8import pytest 

9import torch 

10 

11from NeuralTSNE.TSNE.tests.common import ( 

12 DataLoaderMock, 

13 MyDataset, 

14 PersistentStringIO, 

15) 

16from NeuralTSNE.Utils.Writers.StatWriters import ( 

17 save_means_and_vars, 

18 save_results, 

19) 

20 

21 

def get_random_string(length: int) -> str:
    """Return a random string of *length* characters drawn from letters and digits."""
    alphabet = string.ascii_letters + string.digits
    picked = [random.choice(alphabet) for _ in range(length)]
    return "".join(picked)

24 

25 

@pytest.fixture(params=[""])
def get_output_filename(request):
    """Yield a (random base name, suffix) pair; delete the combined file on teardown if it exists."""
    suffix = request.param
    base_name = get_random_string(10)
    yield base_name, suffix
    # Teardown: clean up whatever file the test may have produced.
    produced = f"{base_name}{suffix}"
    if os.path.exists(produced):
        os.remove(produced)

34 

35 

@pytest.mark.parametrize(
    "data, filtered_data",
    [
        (
            [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]],
            [[1.0, 2.0], [4.0, 5.0], [7.0, 8.0]],
        ),
        (
            [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]],
            [[1.0, 3.0], [4.0, 6.0], [7.0, 9.0]],
        ),
        ([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]], None),
    ],
)
@patch("builtins.open", new_callable=mock_open)
def test_save_means_and_vars(
    mock_open: MagicMock,
    data: List[List[float]],
    filtered_data: List[List[float]] | None,
):
    """Check that save_means_and_vars writes per-column means and sample
    variances for *data*, followed by a second section for *filtered_data*
    when it is provided."""
    sink = PersistentStringIO()
    mock_open.return_value = sink

    # Reference statistics computed independently with numpy (ddof=1 to
    # match the sample variance the writer is expected to emit).
    means = np.mean(data, axis=0)
    variances = np.var(data, axis=0, ddof=1)
    if filtered_data is None:
        filtered_means = filtered_vars = None
    else:
        filtered_means = np.mean(filtered_data, axis=0)
        filtered_vars = np.var(filtered_data, axis=0, ddof=1)

    data = torch.tensor(data)
    if filtered_data is not None:
        filtered_data = torch.tensor(filtered_data)

    save_means_and_vars(data, filtered_data)
    lines = sink.getvalue().splitlines()

    assert lines[0].split() == ["column", "mean", "var"]
    for col, (mean, var) in enumerate(zip(means, variances)):
        assert lines[col + 1].split() == [str(col), str(mean), str(var)]

    if filtered_data is not None:
        # The filtered section starts after the data rows plus a separator.
        offset = len(means) + 2
        assert lines[offset].split() == [
            "filtered_column",
            "filtered_mean",
            "filtered_var",
        ]
        for col, (mean, var) in enumerate(zip(filtered_means, filtered_vars)):
            assert lines[offset + 1 + col].split() == [
                str(col),
                str(mean),
                str(var),
            ]

93 

94 

@pytest.mark.parametrize("batch_shape", [(5, 2), None])
@pytest.mark.parametrize("step", [30, 45, 2])
def test_save_results(
    get_output_filename: Tuple[str, str],
    batch_shape: Tuple[int, int] | None,
    step: int,
):
    """Check that save_results writes the step on the first line followed by
    tab-separated Y entries, and produces no file when no data is given.

    Fixes over the previous version: the dead local ``TQDM_DISABLE = 1`` is
    removed (a bare local assignment never silences tqdm — export the
    ``TQDM_DISABLE`` environment variable instead if needed), and the
    annotations for ``get_output_filename`` (a ``(name, suffix)`` tuple) and
    ``batch_shape`` (``Tuple[int | int]`` is a one-element tuple of ``int``)
    are corrected.
    """
    filename, _ = get_output_filename
    args = {"o": filename, "step": step}

    test_data = None
    if batch_shape:
        num_samples = batch_shape[0] * 10
        dataset = MyDataset(num_samples, batch_shape[1])
        test_data = DataLoaderMock(dataset, batch_size=2)

    # Two batches of random 2-D points to be written out.
    entries_num = random.randint(20, 500)
    Y = [
        [(random.random(), random.random()) for _ in range(entries_num)]
        for _ in range(2)
    ]

    save_results(args, test_data, Y)

    # Without input data no output file should be produced.
    if batch_shape is None:
        assert os.path.exists(filename) is False
        return

    with open(filename, "r") as f:
        lines = f.readlines()

    # First line is the step value; the rest are tab-joined Y entries in order.
    assert lines[0].strip() == str(step)
    expected_lines = [
        "\t".join(map(str, item)) for batch in Y for item in batch
    ]
    for i in range(1, len(lines)):
        assert lines[i].strip() == expected_lines[i - 1]