Coverage for NeuralTSNE/NeuralTSNE/Utils/Writers/LabelWriters/tests/test_label_writers.py: 100%

54 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-05-18 16:32 +0000

1import os 

2import random 

3import string 

4from typing import Tuple 

5from unittest.mock import MagicMock, mock_open, patch 

6 

7import pytest 

8import torch 

9 

10from NeuralTSNE.TSNE.tests.common import ( 

11 DataLoaderMock, 

12 MyDataset, 

13 PersistentStringIO, 

14) 

15from NeuralTSNE.Utils.Writers.LabelWriters import ( 

16 save_labels_data, 

17 save_torch_labels, 

18) 

19 

20 

def get_random_string(length: int) -> str:
    """Return a random alphanumeric string of ``length`` characters.

    Characters are drawn with replacement from the ASCII letters and digits
    via the ``random`` module, so the result is not suitable for anything
    security-sensitive.

    Args:
        length: Number of characters to generate.

    Returns:
        A string of exactly ``length`` characters.
    """
    alphabet = string.ascii_letters + string.digits
    return "".join(random.choices(alphabet, k=length))

23 

24 

@pytest.fixture(params=[""])
def get_output_filename(request: pytest.FixtureRequest):
    """Yield a random base file name with a suffix, then clean up the file.

    The suffix is taken from ``request.param``; it defaults to the empty
    string and can be overridden through indirect parametrization.

    Yields:
        A ``(file_name, suffix)`` tuple, where ``file_name`` is a random
        10-character alphanumeric string.
    """
    suffix = request.param
    file_name = get_random_string(10)
    yield file_name, suffix

    # Teardown: remove ``<file_name><suffix>`` if the test created it.
    try:
        os.remove(f"{file_name}{suffix}")
    except FileNotFoundError:
        # Nothing was written under that name; nothing to clean up.
        pass

33 

34 

@pytest.mark.parametrize("output_path", ["output.txt"])
@patch("builtins.open", new_callable=mock_open)
def test_save_torch_labels(mock_open: MagicMock, output_path: str):
    """Check that ``save_torch_labels`` writes one label per line.

    ``builtins.open`` is patched so that no real file is created; whatever
    the function writes is captured in a ``PersistentStringIO`` instance
    (presumably a StringIO whose contents stay readable after close --
    confirm in ``NeuralTSNE.TSNE.tests.common``).

    Args:
        mock_open: Mock injected by ``patch`` in place of ``builtins.open``.
        output_path: Output path handed to ``save_torch_labels``.
    """
    file_handle = PersistentStringIO()
    mock_open.return_value = file_handle

    data = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
    labels = torch.tensor([1, 2, 3])
    data_set = torch.utils.data.TensorDataset(data, labels)

    save_torch_labels(output_path, data_set)

    # The labels are expected next to the output file, with "_labels"
    # inserted before the ".txt" extension.
    new_file_path = output_path.replace(".txt", "_labels.txt")
    mock_open.assert_called_once_with(new_file_path, "w")

    assert file_handle.getvalue() == "1\n2\n3\n"

53 

54 

55@pytest.mark.parametrize( 

56 "get_output_filename", 

57 ["_labels.txt"], 

58 indirect=["get_output_filename"], 

59) 

60@pytest.mark.parametrize("num_batches", [3, 5]) 

61@pytest.mark.parametrize("batch_shape", [(5, 3), (4, 4), None]) 

62def test_save_labels_data( 

63 get_output_filename: str, num_batches: int, batch_shape: Tuple[int, int] | None 

64): 

65 TQDM_DISABLE = 1 

66 filename, suffix = get_output_filename 

67 args = {"o": filename} 

68 

69 test_data = None 

70 

71 if batch_shape: 

72 num_samples = batch_shape[0] * 10 

73 dataset = MyDataset(num_samples, batch_shape[1], (0, 10), True) 

74 test_data = DataLoaderMock(dataset, batch_size=num_batches) 

75 

76 save_labels_data(args, test_data) 

77 

78 if batch_shape: 

79 data = [ 

80 "\t".join(map(str, row.tolist())) 

81 for batch in test_data.batches 

82 for tensor in batch 

83 for row in tensor 

84 ] 

85 

86 if batch_shape is None: 

87 assert os.path.exists(filename) is False 

88 return 

89 

90 with open(f"{filename}{suffix}", "r") as f: 

91 lines = f.readlines() 

92 

93 for i, line in enumerate(lines): 

94 assert line.strip() == data[i]