Coverage for NeuralTSNE/NeuralTSNE/DatasetLoader/tests/test_loader.py: 100%
59 statements
« prev ^ index » next coverage.py v7.8.0, created at 2025-05-18 16:32 +0000
1from unittest.mock import MagicMock, call, patch
3import pytest
5from NeuralTSNE.DatasetLoader import get_datasets as loader
@patch("torchvision.datasets.MNIST")
def test_get_mnist(mock_mnist: MagicMock):
    """Check get_mnist returns the two MNIST constructor results as a (train, test) pair."""
    expected = ("mocked mnist train", "mocked mnist test")
    # Successive MNIST(...) calls yield the train value first, then the test value.
    mock_mnist.side_effect = list(expected)

    result = loader.get_mnist()

    assert result == expected
    assert mock_mnist.call_count == 2
@patch("torchvision.datasets.FashionMNIST")
def test_get_fashion_mnist(mock_fashion_mnist: MagicMock):
    """Check get_fashion_mnist returns both FashionMNIST constructor results in order."""
    train_value = "mocked fashion mnist train"
    test_value = "mocked fashion mnist test"
    # First constructor call -> train_value, second -> test_value.
    mock_fashion_mnist.side_effect = [train_value, test_value]

    result = loader.get_fashion_mnist()

    assert result == (train_value, test_value)
    assert mock_fashion_mnist.call_count == 2
def test_get_available_datasets():
    """Check _get_available_datasets against the module's public ``get*`` names.

    Every module attribute whose name starts with ``"get"`` contributes the
    name with its first four characters (``"get_"``) stripped. The entry for
    ``get_dataset`` (i.e. ``"dataset"``) is then removed.
    """
    expected = []
    for name in vars(loader):
        if name.startswith("get"):
            expected.append(name[4:])
    # get_dataset is the dispatcher, not a dataset of its own.
    expected.remove("dataset")

    assert loader._get_available_datasets() == expected
@pytest.mark.parametrize("dataset", ["mnist", "fashion_mnist", "abcdef"])
@pytest.mark.parametrize("train_exists", [True, False])
@pytest.mark.parametrize("test_exists", [True, False])
@patch("torch.load")
@patch("torch.save")
@patch("NeuralTSNE.DatasetLoader.get_datasets.os.path.exists")
def test_prepare_dataset(
    mock_exists: MagicMock,
    mock_save: MagicMock,
    mock_load: MagicMock,
    dataset: str,
    train_exists: bool,
    test_exists: bool,
):
    """Exercise prepare_dataset's cached, fetch-and-save and unknown-name paths.

    ``os.path.exists`` is mocked to return ``train_exists`` on its first call
    and ``test_exists`` on its second. This presumably corresponds to the
    ``<dataset>_train.data`` / ``<dataset>_test.data`` cache files.

    - Both exist: both files are read with ``torch.load`` and nothing is saved.
    - Otherwise, an unknown dataset must raise ``KeyError`` without saving.
    - Otherwise, the matching ``get_<dataset>`` is called once and both
      splits are written with ``torch.save``.
    """
    train_file = f"{dataset}_train.data"
    test_file = f"{dataset}_test.data"
    mock_exists.side_effect = [train_exists, test_exists]

    if train_exists and test_exists:
        mock_load.side_effect = ["mocked train", "mocked test"]
        assert loader.prepare_dataset(dataset) == ("mocked train", "mocked test")
        assert mock_load.call_count == 2
        mock_save.assert_not_called()
        mock_load.assert_has_calls([call(train_file), call(test_file)])
        return

    # _get_available_datasets is verified by its own test, so reuse it here
    # rather than duplicating the ``loader.__dict__`` scan.
    if dataset not in loader._get_available_datasets():
        with pytest.raises(KeyError):
            loader.prepare_dataset(dataset)
        # An unknown dataset must not leave anything written to disk.
        mock_save.assert_not_called()
        return

    with patch(f"NeuralTSNE.DatasetLoader.get_datasets.get_{dataset}") as mock_get:
        mock_get.return_value = ("mocked train", "mocked test")
        assert loader.prepare_dataset(dataset) == ("mocked train", "mocked test")
        assert mock_get.call_count == 1
        assert mock_save.call_count == 2
        mock_save.assert_has_calls(
            [
                call("mocked train", train_file),
                call("mocked test", test_file),
            ]
        )
@pytest.mark.parametrize("dataset", ["mnist"])
@pytest.mark.parametrize("is_available", [True, False])
@patch("NeuralTSNE.DatasetLoader.get_datasets.prepare_dataset")
@patch("NeuralTSNE.DatasetLoader.get_datasets._get_available_datasets")
def test_get_dataset(
    mock_available: MagicMock, mock_prepare: MagicMock, dataset: str, is_available: bool
):
    """Check get_dataset delegates to prepare_dataset only for available names.

    When the name is missing from ``_get_available_datasets()``, prepare_dataset
    must not be called and ``(None, None)`` is returned.
    """
    mock_available.return_value = [dataset] if is_available else []
    mock_prepare.return_value = ("mocked train", "mocked test")

    returned = loader.get_dataset(dataset)

    if is_available:
        mock_prepare.assert_called_once_with(dataset)
        assert returned == ("mocked train", "mocked test")
        return

    mock_prepare.assert_not_called()
    assert returned == (None, None)