Coverage for NeuralTSNE/NeuralTSNE/Utils/Writers/LabelWriters/tests/test_label_writers.py: 100%
54 statements
« prev ^ index » next coverage.py v7.8.0, created at 2025-05-18 16:32 +0000
« prev ^ index » next coverage.py v7.8.0, created at 2025-05-18 16:32 +0000
1import os
2import random
3import string
4from typing import Tuple
5from unittest.mock import MagicMock, mock_open, patch
7import pytest
8import torch
10from NeuralTSNE.TSNE.tests.common import (
11 DataLoaderMock,
12 MyDataset,
13 PersistentStringIO,
14)
15from NeuralTSNE.Utils.Writers.LabelWriters import (
16 save_labels_data,
17 save_torch_labels,
18)
def get_random_string(length: int):
    """Return a random string of ``length`` ASCII letters and digits.

    Characters are drawn with replacement from ``string.ascii_letters``
    followed by ``string.digits``.

    Args:
        length: Number of characters to generate.

    Returns:
        The generated string.
    """
    alphabet = string.ascii_letters + string.digits
    picked = random.choices(alphabet, k=length)
    return "".join(picked)
@pytest.fixture(params=[""])
def get_output_filename(request: pytest.FixtureRequest):
    """Provide a random output file name and remove the resulting file afterwards.

    Yields a ``(file_name, suffix)`` pair: ``file_name`` is a random
    10-character string and ``suffix`` is the fixture parameter (``""`` by
    default, overridable through indirect parametrization). After the test
    finishes, the path ``file_name + suffix`` is deleted if it exists.

    Args:
        request: Pytest request object carrying the fixture parameter.
    """
    name_suffix = request.param
    base_name = get_random_string(10)

    yield base_name, name_suffix

    # Teardown: nothing to clean up when the test did not create the file.
    leftover_path = f"{base_name}{name_suffix}"
    if not os.path.exists(leftover_path):
        return
    os.remove(leftover_path)
@pytest.mark.parametrize("output_path", ["output.txt"])
@patch("builtins.open", new_callable=mock_open)
def test_save_torch_labels(mock_open: MagicMock, output_path: str):
    """Check that ``save_torch_labels`` writes one label per line to the labels file.

    ``builtins.open`` is patched, so the call under test receives an in-memory
    handle instead of a real file; the handle's contents are inspected
    afterwards.

    Args:
        mock_open: The patched ``builtins.open`` injected by ``@patch``. Inside
            this function the name shadows the imported ``mock_open`` helper.
        output_path: Output path passed to ``save_torch_labels``.
    """
    # presumably PersistentStringIO keeps its buffer readable after the code
    # under test closes it -- confirm in NeuralTSNE.TSNE.tests.common
    file_handle = PersistentStringIO()
    mock_open.return_value = file_handle

    data = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
    labels = torch.tensor([1, 2, 3])
    data_set = torch.utils.data.TensorDataset(data, labels)

    save_torch_labels(output_path, data_set)

    # The labels file is expected next to the output, with ".txt" replaced
    # by "_labels.txt", and opened exactly once for writing.
    new_file_path = output_path.replace(".txt", "_labels.txt")
    mock_open.assert_called_once_with(new_file_path, "w")

    assert file_handle.getvalue() == "1\n2\n3\n"
@pytest.mark.parametrize(
    "get_output_filename",
    ["_labels.txt"],
    indirect=["get_output_filename"],
)
@pytest.mark.parametrize("num_batches", [3, 5])
@pytest.mark.parametrize("batch_shape", [(5, 3), (4, 4), None])
def test_save_labels_data(
    get_output_filename: Tuple[str, str],
    num_batches: int,
    batch_shape: Tuple[int, int] | None,
):
    """Check that ``save_labels_data`` writes the loader's rows to the labels file.

    With a ``batch_shape`` the test builds a mock data loader, calls
    ``save_labels_data`` and compares every line of ``<filename><suffix>``
    with the tab-joined rows collected from the loader's batches. With
    ``batch_shape=None`` no data is passed and no file may be created.

    Args:
        get_output_filename: ``(filename, suffix)`` pair yielded by the
            ``get_output_filename`` fixture.
        num_batches: Value passed to ``DataLoaderMock`` as ``batch_size``.
        batch_shape: ``(rows, columns)`` used to size the dataset, or ``None``
            to exercise the no-data path.
    """
    filename, suffix = get_output_filename
    labels_path = f"{filename}{suffix}"
    args = {"o": filename}

    if batch_shape is None:
        save_labels_data(args, None)
        # Neither the bare output name nor the suffixed labels file may exist.
        assert not os.path.exists(filename)
        assert not os.path.exists(labels_path)
        return

    num_samples = batch_shape[0] * 10
    dataset = MyDataset(num_samples, batch_shape[1], (0, 10), True)
    test_data = DataLoaderMock(dataset, batch_size=num_batches)

    save_labels_data(args, test_data)

    # Read the batches only after the call, as the original test did --
    # presumably DataLoaderMock records them while being iterated; confirm.
    expected_rows = [
        "\t".join(map(str, row.tolist()))
        for batch in test_data.batches
        for tensor in batch
        for row in tensor
    ]

    with open(labels_path, "r") as f:
        lines = f.readlines()

    # Guard against a vacuous pass: an empty file would skip the loop below.
    assert lines, "labels file is empty"
    for i, line in enumerate(lines):
        assert line.strip() == expected_rows[i]