Coverage for NeuralTSNE/NeuralTSNE/TSNE/Modules/dimensionality_reduction.py: 98%

72 statements  

coverage.py v7.8.0, created at 2025-05-18 16:32 +0000

from typing import Any, List, Tuple, Union

import torch
import torch.optim as optim

import pytorch_lightning as L

from NeuralTSNE.TSNE import ParametricTSNE


class DimensionalityReduction(L.LightningModule):
12 """ 

13 Lightning Module for training a neural network-based 

14 Parametric t-SNE dimensionality reduction model. 

15 

16 Parameters 

17 ---------- 

18 `tsne` : `ParametricTSNE` 

19 Parametric t-SNE model for feature extraction. 

20 `shuffle` : `bool` 

21 Flag indicating whether to shuffle data during training. 

22 `optimizer` : `str`, optional 

23 Optimizer for training. Defaults to `adam`. 

24 `lr` : `float`, optional 

25 Learning rate for the optimizer. Defaults to `1e-3`. 

26 

27 Note 

28 ---- 

29 This class defines a Lightning Module for training a neural network-based 

30 Parametric t-SNE dimensionality reduction model for feature extraction. 

31 It includes methods for the training step, configuring optimizers, and 

32 handling the training process. 

33 """ 

    def __init__(
        self,
        tsne: "ParametricTSNE",
        shuffle: bool,
        optimizer: str = "adam",
        lr: float = 1e-3,
    ):
        super().__init__()
        self.tsne = tsne
        self.batch_size = tsne.batch_size
        self.model = self.tsne.model
        self.loss_fn = tsne.loss_fn
        self.exaggeration_epochs = tsne.early_exaggeration_epochs
        self.exaggeration_value = tsne.early_exaggeration_value
        self.shuffle = shuffle
        self.lr = lr
        self.optimizer = optimizer
        self.reset_exaggeration_status()

    def reset_exaggeration_status(self):
        """
        Reset exaggeration status based on the number of exaggeration epochs.
        """
        # Exaggeration is already "over" when no exaggeration epochs are requested.
        self.has_exaggeration_ended = self.exaggeration_epochs == 0

    def training_step(
        self,
        batch: Union[
            torch.Tensor, Tuple[torch.Tensor, ...], List[Union[torch.Tensor, Any]]
        ],
        batch_idx: int,
    ):
        """
        Perform a single training step.

        Parameters
        ----------
        `batch` : `Union[torch.Tensor, Tuple[torch.Tensor, ...], List[Union[torch.Tensor, Any]]]`
            Input batch.
        `batch_idx` : `int`
            Index of the current batch.

        Returns
        -------
        `Dict[str, torch.Tensor]`
            Dictionary containing the `loss` value.

        Note
        ----
        This method defines a single training step for the dimensionality
        reduction model. It computes the loss from the model's `logits` and
        `_P_batch`, the slice of the joint probability matrix corresponding
        to the current batch.
        """
        x = batch[0]
        _P_batch = self.P_current[
            batch_idx * self.batch_size : (batch_idx + 1) * self.batch_size
        ]

        if self.shuffle:
            # Shuffle the samples and both axes of P with the same permutation
            # so that the pairwise probabilities stay aligned with the data.
            p_idxs = torch.randperm(x.shape[0])
            x = x[p_idxs]
            _P_batch = _P_batch[p_idxs, :]
            _P_batch = _P_batch[:, p_idxs]

        logits = self.model(x)
        loss = self.loss_fn(
            logits,
            _P_batch,
            {"device": self.tsne.device, "batch_size": self.batch_size},
        )
        self.log(
            "train_loss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True
        )
        return {"loss": loss}
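
    # Aside (added for clarity): indexing rows and columns with the same
    # permutation, P[p_idxs][:, p_idxs], gives P_shuffled[i, j] ==
    # P[p_idxs[i], p_idxs[j]], so entry (i, j) still describes the pair
    # (x[i], x[j]) after x = x[p_idxs]. A runnable check of this invariant
    # appears at the bottom of the file.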

    def validation_step(
        self,
        batch: Union[
            torch.Tensor, Tuple[torch.Tensor, ...], List[Union[torch.Tensor, Any]]
        ],
        batch_idx: int,
        dataloader_idx: Union[int, None] = None,
    ):
        """
        Perform a single validation step.

        Parameters
        ----------
        `batch` : `Union[torch.Tensor, Tuple[torch.Tensor, ...], List[Union[torch.Tensor, Any]]]`
            Input batch.
        `batch_idx` : `int`
            Index of the current batch.
        `dataloader_idx` : `int`, optional
            Index of the dataloader.

        Returns
        -------
        `Dict[str, torch.Tensor]`
            Dictionary containing the `loss` value.

        Note
        ----
        This method defines a single validation step for the dimensionality
        reduction model. It computes the loss from the model's `logits` and
        `_P_batch`, the batch's slice of the joint probability matrix
        precomputed for the corresponding validation dataloader.
        """
        x = batch[0]
        # Fall back to the first dataloader's P when Lightning does not pass
        # a dataloader index (i.e., there is only one validation dataloader).
        idx = 0 if dataloader_idx is None else dataloader_idx
        _P_batch = self.val_P[idx][
            batch_idx * self.batch_size : (batch_idx + 1) * self.batch_size
        ]
        logits = self.model(x)
        loss = self.loss_fn(
            logits,
            _P_batch,
            {"device": self.tsne.device, "batch_size": self.batch_size},
        )
        self.log(
            "val_loss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True
        )
        return {"loss": loss}

    def _set_optimizer(
        self, optimizer: str, optimizer_params: dict
    ) -> torch.optim.Optimizer:
        """
        Set the optimizer based on the provided string.

        Parameters
        ----------
        `optimizer` : `str`
            String indicating the desired optimizer.
        `optimizer_params` : `dict`
            Dictionary containing optimizer parameters.

        Returns
        -------
        `torch.optim.Optimizer`
            Initialized optimizer.

        Note
        ----
        This method initializes and returns the desired optimizer based on the
        provided string. Supported values are `adam`, `sgd`, and `rmsprop`;
        any other value raises a `ValueError`.
        """
        if optimizer == "adam":
            return optim.Adam(self.model.parameters(), **optimizer_params)
        elif optimizer == "sgd":
            return optim.SGD(self.model.parameters(), **optimizer_params)
        elif optimizer == "rmsprop":
            return optim.RMSprop(self.model.parameters(), **optimizer_params)
        else:
            raise ValueError(f"Unknown optimizer: {optimizer}")
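
    # For example (illustrative): `self._set_optimizer("sgd", {"lr": 1e-2,
    # "momentum": 0.9})` returns a configured `optim.SGD`; any keyword argument
    # accepted by the chosen torch optimizer can be passed via `optimizer_params`.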

    def configure_optimizers(self) -> torch.optim.Optimizer:
        """
        Configure the optimizer for training.

        Returns
        -------
        `torch.optim.Optimizer`
            Configured optimizer.

        Note
        ----
        This method configures and returns the optimizer for training based on
        the specified parameters.
        """
        return self._set_optimizer(self.optimizer, {"lr": self.lr})

    def on_train_start(self) -> None:
        """
        Perform actions at the beginning of the training process.

        Note
        ----
        This method is called at the start of the training process and calculates
        the joint probability matrix P based on the training dataloader.
        """
        # Compute P only once; repeated calls to `fit` reuse the cached matrix.
        if not hasattr(self, "P"):
            self.P = self.tsne._calculate_P(self.trainer.train_dataloader)

    def on_train_epoch_start(self) -> None:
        """
        Perform actions at the start of each training epoch.

        Note
        ----
        This method is called at the start of each training epoch. If exaggeration
        is enabled and has not ended, it modifies the joint probability matrix for
        the current epoch.
        """
        if self.current_epoch > 0 and self.has_exaggeration_ended:
            return
        if (
            self.exaggeration_epochs > 0
            and self.current_epoch < self.exaggeration_epochs
        ):
            if not hasattr(self, "P_multiplied"):
                self.P_multiplied = self.P.clone()
                self.P_multiplied *= self.exaggeration_value
            self.P_current = self.P_multiplied
        else:
            self.P_current = self.P
            self.has_exaggeration_ended = True
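
    # Worked example (illustrative): with `exaggeration_epochs=3` and
    # `exaggeration_value=12`, epochs 0-2 train against 12 * P (early
    # exaggeration, which encourages tight, well-separated clusters early in
    # training), and from epoch 3 onward training switches back to plain P.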

    def on_train_epoch_end(self) -> None:
        """
        Perform actions at the end of each training epoch.

        Note
        ----
        This method is called at the end of each training epoch. If exaggeration
        has ended and `P_multiplied` exists, it is deleted to free up memory.
        """
        if hasattr(self, "P_multiplied") and self.has_exaggeration_ended:
            del self.P_multiplied

    def on_validation_start(self) -> None:
        """
        Perform actions at the beginning of the validation process.

        Note
        ----
        This method is called at the start of the validation process and calculates
        the joint probability matrix P for each validation dataloader.
        """
        if not hasattr(self, "val_P"):
            self.val_P = [
                self.tsne._calculate_P(loader)
                for loader in self.trainer.val_dataloaders
            ]

    def predict_step(self, batch, batch_idx, dataloader_idx=None):
        """
        Perform a single step during the prediction process.

        Parameters
        ----------
        `batch`
            Input batch.
        `batch_idx` : `int`
            Index of the current batch.
        `dataloader_idx` : `int`, optional
            Index of the dataloader.

        Returns
        -------
        `torch.Tensor`
            Model predictions.

        Note
        ----
        This method is called during the prediction process and returns the
        model's predictions (the low-dimensional embedding) for the input batch.
        """
        return self.model(batch[0])
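
# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative; the `ParametricTSNE(...)` arguments are
# hypothetical placeholders, not the real constructor signature):
#
#   import pytorch_lightning as L
#   from NeuralTSNE.TSNE import ParametricTSNE
#
#   tsne = ParametricTSNE(...)  # construct per the library's documentation
#   module = DimensionalityReduction(tsne, shuffle=True, optimizer="adam", lr=1e-3)
#   trainer = L.Trainer(max_epochs=100)
#   trainer.fit(module, train_dataloaders=train_loader, val_dataloaders=val_loader)
#   embeddings = trainer.predict(module, dataloaders=predict_loader)
# ---------------------------------------------------------------------------

if __name__ == "__main__":
    # Runnable sanity check (added for illustration, not part of the original
    # module): the paired row/column permutation used in `training_step`
    # preserves pairwise structure. Building P as an outer product makes the
    # invariant easy to verify: shuffling the vector and both axes of P with
    # the same permutation must agree.
    a = torch.arange(4.0)
    P = torch.outer(a, a)
    p_idxs = torch.randperm(4)
    a_shuffled = a[p_idxs]
    P_shuffled = P[p_idxs][:, p_idxs]
    assert torch.equal(P_shuffled, torch.outer(a_shuffled, a_shuffled))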