Coverage for NeuralTSNE/NeuralTSNE/MnistPlotter/tests/test_mnist.py: 100%
45 statements
« prev ^ index » next coverage.py v7.8.0, created at 2025-05-18 16:32 +0000
« prev ^ index » next coverage.py v7.8.0, created at 2025-05-18 16:32 +0000
1import io
2from unittest.mock import MagicMock, call, mock_open, patch
4import numpy as np
5import pytest
7import NeuralTSNE.MnistPlotter as plotter
@pytest.mark.parametrize("is_fashion", [True, False])
@patch("matplotlib.pyplot.subplots")
@patch("matplotlib.pyplot.xlabel")
@patch("matplotlib.pyplot.ylabel")
@patch("seaborn.scatterplot")
@patch("matplotlib.pyplot.savefig")
@patch("matplotlib.pyplot.show")
def test_plot(
    mock_plt_show: MagicMock,
    mock_plt_savefig: MagicMock,
    mock_scatterplot: MagicMock,
    mock_plt_ylabel: MagicMock,
    mock_plt_xlabel: MagicMock,
    mock_plt_subplots: MagicMock,
    is_fashion: bool,
):
    """Check which plotting calls ``plotter.plot`` makes.

    Every matplotlib/seaborn entry point touched by the test is patched, so
    nothing is drawn or written to disk. The test passes when each patched
    function was called exactly once and ``savefig`` received the requested
    image file name. It runs once for each value of ``is_fashion``.
    """
    img_file = "test.png"
    points = np.array([[1, 2], [3, 4]])
    point_labels = np.array([1, 2])

    # ``subplots`` is patched, so supply a placeholder 2-tuple as its result.
    mock_plt_subplots.return_value = (None, None)

    plotter.plot(points, point_labels, is_fashion, img_file)

    mock_scatterplot.assert_called_once()
    mock_plt_savefig.assert_called_once_with(img_file)

    # The remaining patched functions only need to have been called once each.
    for patched in (
        mock_plt_show,
        mock_plt_xlabel,
        mock_plt_ylabel,
        mock_plt_subplots,
    ):
        patched.assert_called_once()
@pytest.mark.parametrize("input_file", ["test_data.txt"])
@pytest.mark.parametrize("labels_file", ["test_labels.txt", None])
@pytest.mark.parametrize("is_fashion", [True, False])
@patch("builtins.open", new_callable=mock_open)
@patch("numpy.loadtxt")
@patch("NeuralTSNE.MnistPlotter.mnist_plot.plot")
def test_plot_from_file(
    mock_plot: MagicMock,
    mock_loadtxt: MagicMock,
    mock_files: MagicMock,
    input_file: str | None,
    labels_file: str | None,
    is_fashion: bool,
):
    """Check how ``plotter.plot_from_file`` opens, loads and forwards data.

    ``open``, ``numpy.loadtxt`` and the underlying ``plot`` function are all
    patched. The test verifies that the input file (and the labels file, when
    one is given) is opened in ``"r"`` mode, that ``loadtxt`` receives the
    opened handles, and that ``plot`` is finally called once with the loaded
    arrays, the ``is_fashion`` flag and an image name derived from the input
    file name by swapping its extension for ``.png``.
    """
    file_content = "10\n1 2 3\n4 5 6"
    labels_content = "1\n2"

    # One in-memory handle per file the function is expected to open.
    handlers = [io.StringIO(file_content)]
    if labels_file:
        handlers.append(io.StringIO(labels_content))
    mock_files.side_effect = handlers

    # Expected data is built from every line after the first one.
    data_rows = [row.split() for row in file_content.splitlines()[1:]]
    data = np.array(data_rows, dtype="float32")

    label_rows = [row.split() for row in labels_content.splitlines()]
    labels = None
    if labels_file:
        labels = np.array(label_rows, dtype="int32")

    mock_loadtxt.side_effect = [data, labels]

    plotter.plot_from_file(input_file, labels_file, is_fashion)

    if not labels_file:
        # Without labels, only the data file may be opened and loaded.
        mock_loadtxt.assert_called_once_with(handlers[0])
        mock_files.assert_called_once_with(input_file, "r")
    else:
        expected_loads = [call(handlers[0]), call(handlers[1], dtype="int")]
        mock_loadtxt.assert_has_calls(expected_loads, any_order=True)

        expected_opens = [call(input_file, "r"), call(labels_file, "r")]
        mock_files.assert_has_calls(expected_opens, any_order=True)

    expected_image = input_file.rsplit(".", 1)[0] + ".png"
    mock_plot.assert_called_once_with(data, labels, is_fashion, expected_image)