Coverage for NeuralTSNE/NeuralTSNE/MnistPlotter/tests/test_mnist.py: 100%

45 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-05-18 16:32 +0000

1import io 

2from unittest.mock import MagicMock, call, mock_open, patch 

3 

4import numpy as np 

5import pytest 

6 

7import NeuralTSNE.MnistPlotter as plotter 

8 

9 

@pytest.mark.parametrize("is_fashion", [True, False])
@patch("matplotlib.pyplot.subplots")
@patch("matplotlib.pyplot.xlabel")
@patch("matplotlib.pyplot.ylabel")
@patch("seaborn.scatterplot")
@patch("matplotlib.pyplot.savefig")
@patch("matplotlib.pyplot.show")
def test_plot(
    mock_plt_show: MagicMock,
    mock_plt_savefig: MagicMock,
    mock_scatterplot: MagicMock,
    mock_plt_ylabel: MagicMock,
    mock_plt_xlabel: MagicMock,
    mock_plt_subplots: MagicMock,
    is_fashion: bool,
):
    """Run ``plotter.plot`` with every matplotlib/seaborn primitive mocked out.

    The ``@patch`` decorators inject their mocks bottom-up, which is why
    ``mock_plt_show`` is the first parameter. ``is_fashion`` is forwarded to
    ``plot`` for both values, but none of the assertions below depend on it.
    """
    output_path = "test.png"
    points = np.array([[1, 2], [3, 4]])
    point_labels = np.array([1, 2])

    # ``plot`` presumably unpacks ``plt.subplots()`` into (figure, axes); a pair
    # of placeholders is enough for this test — TODO confirm against mnist_plot.
    mock_plt_subplots.return_value = None, None

    plotter.plot(points, point_labels, is_fashion, output_path)

    # Each of these plotting primitives must be invoked exactly once.
    for single_call_mock in (
        mock_plt_subplots,
        mock_plt_xlabel,
        mock_plt_ylabel,
        mock_scatterplot,
        mock_plt_show,
    ):
        single_call_mock.assert_called_once()

    # The figure must be saved to exactly the path that was passed in.
    mock_plt_savefig.assert_called_once_with(output_path)

36 

37 

@pytest.mark.parametrize("input_file", ["test_data.txt"])
@pytest.mark.parametrize("labels_file", ["test_labels.txt", None])
@pytest.mark.parametrize("is_fashion", [True, False])
@patch("builtins.open", new_callable=mock_open)
@patch("numpy.loadtxt")
@patch("NeuralTSNE.MnistPlotter.mnist_plot.plot")
def test_plot_from_file(
    mock_plot: MagicMock,
    mock_loadtxt: MagicMock,
    mock_files: MagicMock,
    input_file: str,
    labels_file: str | None,
    is_fashion: bool,
):
    """Check that ``plot_from_file`` opens its inputs and forwards them to ``plot``.

    ``open`` is patched so each call returns an in-memory handle: the data file
    first, then the labels file when ``labels_file`` is given. ``numpy.loadtxt``
    is patched to return pre-built arrays, so no real files are read. The test
    then verifies which files were opened, what ``loadtxt`` received, and the
    arguments ``plot`` was called with.
    """
    file_content = "10\n1 2 3\n4 5 6"
    labels_content = "1\n2"

    # One handle per expected ``open`` call; ``side_effect`` hands them out in
    # order, so the data-file handle must come first.
    handlers = [io.StringIO(file_content)]
    if labels_file:
        handlers.append(io.StringIO(labels_content))
    mock_files.side_effect = handlers

    # Expected data skips the first line ("10"). That line is presumably a
    # header consumed by ``plot_from_file`` before ``loadtxt`` — TODO confirm.
    data = np.array(
        [line.split() for line in file_content.splitlines()[1:]], dtype="float32"
    )
    labels = (
        np.array(
            [line.split() for line in labels_content.splitlines()], dtype="int32"
        )
        if labels_file
        else None
    )

    # The first ``loadtxt`` call yields the data and the second (if any) the labels.
    mock_loadtxt.side_effect = [data, labels]

    plotter.plot_from_file(input_file, labels_file, is_fashion)

    if labels_file:
        mock_loadtxt.assert_has_calls(
            [call(handlers[0]), call(handlers[1], dtype="int")],
            any_order=True,
        )
        mock_files.assert_has_calls(
            [call(input_file, "r"), call(labels_file, "r")], any_order=True
        )
    else:
        mock_loadtxt.assert_called_once_with(handlers[0])
        mock_files.assert_called_once_with(input_file, "r")

    # ``data``/``labels`` are the exact objects returned by the mocked
    # ``loadtxt``. If ``plot_from_file`` forwards them unchanged, the mock call
    # comparison matches them by identity and avoids ambiguous ndarray ``==``.
    # The output path replaces the input's extension with ".png".
    mock_plot.assert_called_once_with(
        data, labels, is_fashion, input_file.rsplit(".", 1)[0] + ".png"
    )