Coverage for NeuralTSNE/NeuralTSNE/DatasetLoader/tests/test_loader.py: 100%

59 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-05-18 16:32 +0000

1from unittest.mock import MagicMock, call, patch 

2 

3import pytest 

4 

5from NeuralTSNE.DatasetLoader import get_datasets as loader 

6 

7 

@patch("torchvision.datasets.MNIST")
def test_get_mnist(mock_mnist: MagicMock):
    """``get_mnist`` returns the two objects built by ``torchvision.datasets.MNIST``.

    The patched constructor yields a train stand-in on its first call and a
    test stand-in on its second, so the returned pair must keep that order
    and the constructor must be invoked exactly twice.
    """
    expected = ("mocked mnist train", "mocked mnist test")
    mock_mnist.side_effect = list(expected)

    result = loader.get_mnist()

    assert result == expected
    assert mock_mnist.call_count == 2

13 

14 

@patch("torchvision.datasets.FashionMNIST")
def test_get_fashion_mnist(mock_fashion_mnist: MagicMock):
    """``get_fashion_mnist`` returns the pair built by ``FashionMNIST``.

    The patched constructor returns the train stand-in first and the test
    stand-in second; the function must return them in that order after
    constructing the dataset exactly twice.
    """
    expected = ("mocked fashion mnist train", "mocked fashion mnist test")
    mock_fashion_mnist.side_effect = list(expected)

    result = loader.get_fashion_mnist()

    assert result == expected
    assert mock_fashion_mnist.call_count == 2

26 

27 

def test_get_available_datasets():
    """``_get_available_datasets`` lists the suffixes of the module's ``get*`` names.

    The expectation is rebuilt from the module namespace: every public name
    starting with ``get`` has its first four characters stripped (the
    ``get_`` prefix), and ``get_dataset`` itself is excluded.
    """
    expected = [name[4:] for name in vars(loader) if name.startswith("get")]
    # get_dataset is the generic entry point, not a loader for a dataset
    # called "dataset", so it must not appear in the list.
    expected.remove("dataset")

    assert loader._get_available_datasets() == expected

33 

34 

@pytest.mark.parametrize("dataset", ["mnist", "fashion_mnist", "abcdef"])
@pytest.mark.parametrize("train_exists", [True, False])
@pytest.mark.parametrize("test_exists", [True, False])
@patch("torch.load")
@patch("torch.save")
@patch("NeuralTSNE.DatasetLoader.get_datasets.os.path.exists")
def test_prepare_dataset(
    mock_exists: MagicMock,
    mock_save: MagicMock,
    mock_load: MagicMock,
    dataset: str,
    train_exists: bool,
    test_exists: bool,
):
    """Exercise ``prepare_dataset`` for cached, unknown and regenerated datasets.

    Scenarios:

    * both split files exist: ``<dataset>_train.data`` and
      ``<dataset>_test.data`` are read with ``torch.load`` and nothing is
      saved;
    * a split file is missing and the dataset has no ``get_<dataset>``
      loader: ``KeyError`` is raised and nothing is saved;
    * a split file is missing and the dataset is known: ``get_<dataset>`` is
      called once and both splits are written with ``torch.save``.
    """
    train_file = dataset + "_train.data"
    test_file = dataset + "_test.data"
    # Successive os.path.exists results; presumably queried train first,
    # then test — TODO confirm against prepare_dataset.
    mock_exists.side_effect = [train_exists, test_exists]

    if train_exists and test_exists:
        mock_load.side_effect = ["mocked train", "mocked test"]
        assert loader.prepare_dataset(dataset) == ("mocked train", "mocked test")
        assert mock_load.call_count == 2
        mock_save.assert_not_called()
        mock_load.assert_has_calls([call(train_file), call(test_file)])
        return

    # Same derivation as the module's get_* naming convention checked in
    # test_get_available_datasets.
    l_dict = loader.__dict__
    available_datasets = [key[4:] for key in l_dict if key.startswith("get")]
    available_datasets.remove("dataset")

    if dataset not in available_datasets:
        with pytest.raises(KeyError):
            loader.prepare_dataset(dataset)
        # An unknown dataset must not leave anything persisted to disk.
        mock_save.assert_not_called()
        return

    with patch("NeuralTSNE.DatasetLoader.get_datasets.get_" + dataset) as mock_get:
        mock_get.return_value = ("mocked train", "mocked test")
        assert loader.prepare_dataset(dataset) == ("mocked train", "mocked test")
        mock_get.assert_called_once()

    assert mock_save.call_count == 2
    mock_save.assert_has_calls(
        [
            call("mocked train", train_file),
            call("mocked test", test_file),
        ]
    )

82 

83 

@pytest.mark.parametrize("dataset", ["mnist"])
@pytest.mark.parametrize("is_available", [True, False])
@patch("NeuralTSNE.DatasetLoader.get_datasets.prepare_dataset")
@patch("NeuralTSNE.DatasetLoader.get_datasets._get_available_datasets")
def test_get_dataset(
    mock_available: MagicMock, mock_prepare: MagicMock, dataset: str, is_available: bool
):
    """``get_dataset`` delegates to ``prepare_dataset`` only for listed datasets.

    If the name is in ``_get_available_datasets()``, the result of
    ``prepare_dataset(dataset)`` is returned unchanged. Otherwise
    ``(None, None)`` is returned and ``prepare_dataset`` is never called.
    """
    mock_available.return_value = [dataset] if is_available else []
    mock_prepare.return_value = ("mocked train", "mocked test")

    returned = loader.get_dataset(dataset)

    if is_available:
        mock_prepare.assert_called_once_with(dataset)
        assert returned == ("mocked train", "mocked test")
    else:
        mock_prepare.assert_not_called()
        assert returned == (None, None)