Coverage for NeuralTSNE/NeuralTSNE/DatasetLoader/get_datasets.py: 100%
32 statements
« prev ^ index » next coverage.py v7.8.0, created at 2025-05-18 16:32 +0000
« prev ^ index » next coverage.py v7.8.0, created at 2025-05-18 16:32 +0000
1import os
2from typing import List, Tuple
4import torch
5from torch import flatten
6from torch.utils.data import Dataset
7from torchvision import datasets
8from torchvision.transforms import Compose, Lambda, ToTensor
def get_mnist() -> Tuple[Dataset, Dataset]:
    """
    Fetches the MNIST train and test splits through `torchvision`.

    Both splits are stored under the `data` directory (downloaded when
    missing). Every sample is converted with `ToTensor` and then passed
    through `torch.flatten`.

    Returns
    -------
    `Tuple[Dataset, Dataset]`
        Tuple containing training and testing datasets.
    """

    def load_split(is_train: bool) -> Dataset:
        # Each split gets its own freshly built transform pipeline.
        return datasets.MNIST(
            root="data",
            train=is_train,
            download=True,
            transform=Compose([ToTensor(), Lambda(flatten)]),
        )

    # Tuple elements are evaluated left to right: training split first.
    return load_split(True), load_split(False)
def get_fashion_mnist() -> Tuple[Dataset, Dataset]:
    """
    Fetches the Fashion MNIST train and test splits through `torchvision`.

    Both splits are stored under the `data` directory (downloaded when
    missing). Every sample is converted with `ToTensor` and then passed
    through `torch.flatten`.

    Returns
    -------
    `Tuple[Dataset, Dataset]`
        Tuple containing training and testing datasets.
    """
    # Build the training split first, then the testing split; a new
    # transform pipeline is created for each of them.
    train_split, test_split = [
        datasets.FashionMNIST(
            root="data",
            train=use_train,
            download=True,
            transform=Compose([ToTensor(), Lambda(flatten)]),
        )
        for use_train in (True, False)
    ]
    return train_split, test_split
def _get_available_datasets() -> List[str]:
    """
    Gets list of available datasets.

    A dataset counts as available when this module's globals contain a
    name of the form `get_<dataset>`. The `get_dataset` entry point follows
    the same naming pattern but is not a loader, so it is left out.

    Returns
    -------
    `List[str]`
        List of available datasets (the global names with the `get_`
        prefix stripped), in the order the globals were defined.
    """
    prefix = "get_"
    # Match the whole "get_" prefix, underscore included: matching only
    # "get" while stripping four characters would turn an unrelated global
    # such as `getpass` into the bogus dataset name "pass".
    # `get_dataset` is filtered out here rather than removed afterwards, so
    # its absence from the globals cannot raise `ValueError`.
    return [
        key.removeprefix(prefix)
        for key in globals()
        if key.startswith(prefix) and key != "get_dataset"
    ]
def prepare_dataset(dataset_name: str) -> Tuple[Dataset, Dataset]:
    """
    Loads the dataset from its cache files or creates it if they do not exist.
    Returns the training and testing datasets.

    The cache consists of the two files `<dataset_name>_train.data` and
    `<dataset_name>_test.data`. If either one is missing, both splits are
    rebuilt by calling the module-level `get_<dataset_name>` function and
    both files are written again.

    Parameters
    ----------
    `dataset_name` : `str`
        Name of the dataset.

    Returns
    -------
    `Tuple[Dataset, Dataset]`
        Tuple containing training and testing datasets.

    Raises
    ------
    `KeyError`
        If the cache is incomplete and this module has no
        `get_<dataset_name>` global.
    """
    train_path = dataset_name + "_train.data"
    test_path = dataset_name + "_test.data"

    if os.path.exists(train_path) and os.path.exists(test_path):
        # The cache holds whole pickled Dataset objects (see `torch.save`
        # below), not plain tensors. Since PyTorch 2.6 `torch.load`
        # defaults to `weights_only=True`, which refuses to unpickle such
        # objects, so the full unpickler has to be requested explicitly.
        # SECURITY: full unpickling can execute arbitrary code; only files
        # written by this function should ever be placed at these paths.
        train = torch.load(train_path, weights_only=False)
        test = torch.load(test_path, weights_only=False)
        return train, test

    # Cache miss: look up the matching loader by name and persist both splits.
    train, test = globals()["get_" + dataset_name]()
    torch.save(train, train_path)
    torch.save(test, test_path)
    return train, test
def get_dataset(dataset_name: str) -> Tuple[Dataset, Dataset] | Tuple[None, None]:
    """
    Looks up a dataset by name among the available datasets.

    The name is lower-cased before the lookup, so matching is
    case-insensitive.

    Parameters
    ----------
    `dataset_name` : `str`
        Name of the dataset.

    Returns
    -------
    `Tuple[Dataset, Dataset]` | `Tuple[None, None]`
        Tuple containing training and testing datasets,
        or `(None, None)` if the dataset is not available.
    """
    normalized = dataset_name.lower()

    # Unknown dataset: signal it with a pair of Nones instead of raising.
    if normalized not in _get_available_datasets():
        return None, None

    return prepare_dataset(normalized)