Coverage for NeuralTSNE/NeuralTSNE/MnistPlotter/mnist_plot.py: 97%
25 statements
coverage.py v7.8.0, created at 2025-05-18 16:32 +0000
1import matplotlib.pyplot as plt
2import numpy as np
3import seaborn as sns
def plot(
    data: np.ndarray, labels: np.ndarray, is_fashion: bool = False, img_file: str = None
) -> None:
    """
    Plot t-SNE results of mnist dataset.

    Parameters
    ----------
    `data` : `np.ndarray`
        t-SNE data to be plotted. Column 0 is used as the x coordinate and
        column 1 as the y coordinate.
    `labels` : `np.ndarray`
        Labels corresponding to the data points. Only the first ``len(data)``
        labels are used. May be ``None``, in which case the points are drawn
        without per-class coloring.
    `is_fashion` : `bool`, optional
        Flag indicating whether the dataset is a fashion dataset. When set,
        integer labels are displayed as Fashion-MNIST class names; otherwise
        the integer label itself is displayed.
    `img_file` : `str`, optional
        File path to save the plot as an image. If not given (or empty), the
        plot is only shown, not saved.

    Note
    ----
    This function plots the t-SNE results with colored points based on the provided labels.
    """
    fashion_classes = (
        "T-shirt/top",
        "Trouser",
        "Pullover",
        "Dress",
        "Coat",
        "Sandal",
        "Shirt",
        "Sneaker",
        "Bag",
        "Ankle boot",
    )

    if labels is None:
        # No labels available (plot_from_file passes None when no labels file
        # is given): draw plain points. Passing a palette without a hue would
        # only make seaborn emit a warning.
        hue_kwargs = {}
    else:
        label_subset = labels[: len(data)]
        if is_fashion:
            hue = [fashion_classes[label] for label in label_subset]
        else:
            # Digit labels are their own display names.
            hue = [int(label) for label in label_subset]
        # A concrete list (not a lazy map iterator) so seaborn receives a
        # sized vector aligned with x/y.
        hue_kwargs = {"hue": hue, "palette": "Paired", "legend": "full"}

    plt.subplots(1, 1)

    sns.scatterplot(x=data[:, 0], y=data[:, 1], **hue_kwargs)
    plt.xlabel("t-SNE 1")
    plt.ylabel("t-SNE 2")

    if img_file:
        plt.savefig(img_file)
    plt.show()
def plot_from_file(file: str, labels_file: str, is_fashion: bool = False) -> None:
    """
    Plot t-SNE results of mnist dataset from file.

    Parameters
    ----------
    `file` : `str`
        File path containing t-SNE data. The first line must parse as an
        integer; it is read and discarded. The remaining lines are loaded
        with `np.loadtxt`.
    `labels_file` : `str`
        File path containing labels data, loaded as integers. If empty or
        ``None``, no labels are loaded and ``None`` is passed to `plot`.
    `is_fashion` : `bool`, optional
        Flag indicating whether the dataset is a fashion dataset.

    Note
    ----
    This function reads t-SNE data and labels from files and plots the results
    using the `plot` function. The image is saved next to `file`, with the
    extension of its final path component replaced by ``.png``.
    """
    import os

    with open(file, "r") as f:
        # The header line must be an integer, but its value is not used here.
        int(f.readline())
        data = np.loadtxt(f)

    labels = None
    if labels_file:
        with open(labels_file, "r") as f:
            labels = np.loadtxt(f, dtype="int")

    # splitext only strips an extension from the last path component; the
    # previous rsplit(".") would cut at a dot inside a directory name
    # (e.g. "run.1/out" -> "run.png" instead of "run.1/out.png").
    image_file = os.path.splitext(file)[0] + ".png"
    plot(data, labels, is_fashion, image_file)