Coverage for NeuralTSNE/NeuralTSNE/DatasetLoader/get_datasets.py: 100%

32 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-05-18 16:32 +0000

1import os 

2from typing import List, Tuple 

3 

4import torch 

5from torch import flatten 

6from torch.utils.data import Dataset 

7from torchvision import datasets 

8from torchvision.transforms import Compose, Lambda, ToTensor 

9 

10 

def get_mnist() -> Tuple[Dataset, Dataset]:
    """
    Builds the training and testing splits of MNIST through `torchvision`.

    Both splits use the `data` directory as their root and are requested
    with `download=True`. Every sample goes through `ToTensor` and is then
    passed to `torch.flatten`.

    Returns
    -------
    `Tuple[Dataset, Dataset]`
        The training split followed by the testing split.
    """
    # The generator is consumed in order, so the training split is built
    # first and the testing split second.
    train_split, test_split = (
        datasets.MNIST(
            root="data",
            train=is_train,
            download=True,
            transform=Compose([ToTensor(), Lambda(flatten)]),
        )
        for is_train in (True, False)
    )
    return train_split, test_split

34 

35 

def get_fashion_mnist() -> Tuple[Dataset, Dataset]:
    """
    Builds the training and testing splits of Fashion MNIST through
    `torchvision`.

    Both splits use the `data` directory as their root and are requested
    with `download=True`. Every sample goes through `ToTensor` and is then
    passed to `torch.flatten`.

    Returns
    -------
    `Tuple[Dataset, Dataset]`
        The training split followed by the testing split.
    """

    def load_split(is_train: bool) -> Dataset:
        # Each split gets its own transform pipeline instance.
        pipeline = Compose([ToTensor(), Lambda(flatten)])
        return datasets.FashionMNIST(
            root="data",
            train=is_train,
            download=True,
            transform=pipeline,
        )

    training_split = load_split(True)
    testing_split = load_split(False)
    return training_split, testing_split

60 

61 

def _get_available_datasets() -> List[str]:
    """
    Gets list of available datasets.

    A dataset counts as available when this module defines a callable
    named `get_<name>`. The `get_dataset` dispatcher matches that pattern
    too, but it is not a dataset getter, so it is left out.

    Returns
    -------
    `List[str]`
        Names of the available datasets, without the `get_` prefix.
    """
    prefix = "get_"
    # Names that match the prefix but are not dataset getters.
    reserved = {"dataset"}
    return [
        # Strip exactly the prefix that was matched, so names such as
        # `getter` can never leak in as a truncated dataset name.
        name[len(prefix):]
        for name, member in globals().items()
        if name.startswith(prefix)
        and callable(member)
        and name[len(prefix):] not in reserved
    ]

74 

75 

def prepare_dataset(dataset_name: str) -> Tuple[Dataset, Dataset]:
    """
    Loads the dataset from file or creates it if it does not exist.
    Returns the training and testing datasets.

    The datasets are cached in `<dataset_name>_train.data` and
    `<dataset_name>_test.data`. When either file is missing, both datasets
    are rebuilt by calling this module's `get_<dataset_name>` function and
    both cache files are rewritten.

    Parameters
    ----------
    `dataset_name` : `str`
        Name of the dataset.

    Returns
    -------
    `Tuple[Dataset, Dataset]`
        Tuple containing training and testing datasets.

    Raises
    ------
    `KeyError`
        If the cache files are missing and this module has no
        `get_<dataset_name>` function.
    """
    train_path = dataset_name + "_train.data"
    test_path = dataset_name + "_test.data"

    if os.path.exists(train_path) and os.path.exists(test_path):
        # The cache holds whole pickled `Dataset` objects, not tensors or
        # state dicts. PyTorch 2.6 changed the `torch.load` default to
        # `weights_only=True`, which rejects such objects, so full
        # unpickling has to be requested explicitly.
        # NOTE: full unpickling can execute arbitrary code. This is only
        # acceptable because the files are written by this function; do not
        # point it at cache files from an untrusted source.
        train = torch.load(train_path, weights_only=False)
        test = torch.load(test_path, weights_only=False)
        return train, test

    train, test = globals()["get_" + dataset_name]()
    torch.save(train, train_path)
    torch.save(test, test_path)
    return train, test

102 

103 

def get_dataset(dataset_name: str) -> Tuple[Dataset, Dataset] | Tuple[None, None]:
    """
    Looks up a dataset by name among the available datasets.

    The name is lower-cased before the lookup, so matching is
    case-insensitive.

    Parameters
    ----------
    `dataset_name` : `str`
        Name of the dataset.

    Returns
    -------
    `Tuple[Dataset, Dataset]` | `Tuple[None, None]`
        Training and testing datasets, or `(None, None)` when the
        dataset is not available.
    """
    normalized = dataset_name.lower()
    if normalized not in _get_available_datasets():
        return None, None
    return prepare_dataset(normalized)