Source code for NeuralTSNE.MnistPlotter.mnist_plot

import os

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns


def plot(
    data: np.ndarray,
    labels: np.ndarray,
    is_fashion: bool = False,
    img_file: str = None,
) -> None:
    """
    Plot t-SNE results of mnist dataset.

    Parameters
    ----------
    `data` : `np.ndarray`
        t-SNE data to be plotted. Column 0 is used as the x coordinate and
        column 1 as the y coordinate.
    `labels` : `np.ndarray`
        Labels corresponding to the data points. Each label is used as an
        index into the list of class names, and only the first `len(data)`
        labels are read. May be `None`, in which case the points are drawn
        without class coloring.
    `is_fashion` : `bool`, optional
        Flag indicating whether the dataset is a fashion dataset. When set,
        labels are shown as clothing category names instead of digits.
    `img_file` : `str`, optional
        File path to save the plot as an image. Nothing is saved when this
        is `None` or empty.

    Note
    ----
    This function plots the t-SNE results with colored points
    based on the provided labels. The figure is saved (if requested)
    before it is shown.
    """
    hue = None
    if labels is not None:
        if is_fashion:
            classes = [
                "T-shirt/top",
                "Trouser",
                "Pullover",
                "Dress",
                "Coat",
                "Sandal",
                "Shirt",
                "Sneaker",
                "Bag",
                "Ankle boot",
            ]
        else:
            classes = list(range(10))
        # Build a concrete list rather than a lazy `map` object: an iterator
        # has no length and can only be consumed once.
        hue = [classes[label] for label in labels[: len(data)]]

    plt.subplots(1, 1)
    sns.scatterplot(
        x=data[:, 0],
        y=data[:, 1],
        hue=hue,
        # A palette is only meaningful when there is a hue variable to color.
        palette="Paired" if hue is not None else None,
        legend="full",
    )
    plt.xlabel("t-SNE 1")
    plt.ylabel("t-SNE 2")

    # Save before show(): the figure may no longer be available afterwards.
    if img_file:
        plt.savefig(img_file)
    plt.show()
def plot_from_file(file: str, labels_file: str, is_fashion: bool = False) -> None:
    """
    Plot t-SNE results of mnist dataset from file.

    Parameters
    ----------
    `file` : `str`
        File path containing t-SNE data. The first line must hold a single
        integer header, which is skipped; the remaining lines are loaded
        with `np.loadtxt`.
    `labels_file` : `str`
        File path containing labels data, loaded as integers. If it is empty
        or `None`, no labels are loaded and `None` is passed on as the labels.
    `is_fashion` : `bool`, optional
        Flag indicating whether the dataset is a fashion dataset.

    Raises
    ------
    `ValueError`
        If the first line of `file` is not an integer.

    Note
    ----
    This function reads t-SNE data and labels from files and plots the results
    using the `plot` function. The plot is saved next to `file`, with the
    file extension replaced by `.png`.
    """
    with open(file, "r") as f:
        # Parse the header line so a malformed file fails loudly; the value
        # itself is not needed.
        _ = int(f.readline())
        data = np.loadtxt(f)

    labels = None
    if labels_file:
        with open(labels_file, "r") as f:
            labels = np.loadtxt(f, dtype="int")

    # splitext only inspects the last path component, so a dot in a directory
    # name (e.g. "./runs/out") cannot truncate the output path.
    image_path = os.path.splitext(file)[0] + ".png"
    plot(data, labels, is_fashion, image_path)