Coverage for NeuralTSNE/NeuralTSNE/MnistPlotter/mnist_plot.py: 97%

25 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-05-18 16:32 +0000

1import matplotlib.pyplot as plt 

2import numpy as np 

3import seaborn as sns 

4 

5 

def plot(
    data: np.ndarray, labels: np.ndarray, is_fashion: bool = False, img_file: str = None
) -> None:
    """
    Plot t-SNE results of mnist dataset.

    Parameters
    ----------
    `data` : `np.ndarray`
        t-SNE data to be plotted. Column 0 is drawn on the x axis and
        column 1 on the y axis.
    `labels` : `np.ndarray`
        Labels corresponding to the data points. Each label is used as an
        index into the list of ten class names. Only the first `len(data)`
        labels are used. May be `None`, in which case the points are drawn
        without per-class colouring.
    `is_fashion` : `bool`, optional
        Flag indicating whether the dataset is a fashion dataset. When set,
        the legend shows clothing category names instead of the digits 0-9.
    `img_file` : `str`, optional
        File path to save the plot as an image. Nothing is saved when this
        is `None` or empty.

    Note
    ----
    This function plots the t-SNE results with colored points based on the provided labels.
    The figure is always shown with `plt.show()`, whether or not it was saved.
    """
    if is_fashion:
        classes = [
            "T-shirt/top",
            "Trouser",
            "Pullover",
            "Dress",
            "Coat",
            "Sandal",
            "Shirt",
            "Sneaker",
            "Bag",
            "Ankle boot",
        ]
    else:
        classes = list(range(10))

    plt.subplots(1, 1)

    # Hue-related options are only passed when labels exist: indexing `None`
    # would raise, and a palette/legend without a hue variable has no meaning.
    hue_options = {}
    if labels is not None:
        hue_options = {
            # Materialised as a list so the hue vector has a definite length
            # and can be read more than once (a `map` object is one-shot).
            "hue": [classes[label] for label in labels[: len(data)]],
            "palette": "Paired",
            "legend": "full",
        }

    sns.scatterplot(x=data[:, 0], y=data[:, 1], **hue_options)
    plt.xlabel("t-SNE 1")
    plt.ylabel("t-SNE 2")

    if img_file:
        plt.savefig(img_file)
    plt.show()

59 

60 

def plot_from_file(file: str, labels_file: str, is_fashion: bool = False) -> None:
    """
    Plot t-SNE results of mnist dataset from file.

    Parameters
    ----------
    `file` : `str`
        File path containing t-SNE data. The first line must hold a single
        integer (it is parsed and discarded); the remaining lines are read
        with `np.loadtxt`.
    `labels_file` : `str`
        File path containing labels data, read as integers. If falsy, no
        labels are loaded and `None` is passed to `plot` as the labels.
    `is_fashion` : `bool`, optional
        Flag indicating whether the dataset is a fashion dataset.

    Raises
    ------
    `ValueError`
        If the first line of `file` is not an integer.

    Note
    ----
    This function reads t-SNE data and labels from files and plots the results using the `plot` function.
    The image is saved next to `file`, with its extension replaced by `.png`.
    """
    import os

    with open(file, "r") as f:
        # Header line: parsed only to validate it, the value itself is unused.
        int(f.readline())
        # ndmin=2 keeps a single-row file two-dimensional so that column
        # indexing in `plot` still works.
        data = np.loadtxt(f, ndmin=2)

    labels = None
    if labels_file:
        with open(labels_file, "r") as f:
            # ndmin=1 keeps a single label sliceable instead of a 0-d array.
            labels = np.loadtxt(f, dtype="int", ndmin=1)

    # splitext only strips an extension from the final path component, so a
    # dot in a directory name (e.g. "./out/data") does not truncate the path.
    image_path = os.path.splitext(file)[0] + ".png"
    plot(data, labels, is_fashion, image_path)