Coverage for NeuralTSNE/NeuralTSNE/TSNE/Modules/dimensionality_reduction.py: 98%
72 statements
« prev ^ index » next coverage.py v7.8.0, created at 2025-05-18 16:32 +0000
1from typing import Any, List, Tuple, Union
3import torch
4import torch.optim as optim
6import pytorch_lightning as L
8from NeuralTSNE.TSNE import ParametricTSNE
class DimensionalityReduction(L.LightningModule):
    """
    Lightning Module for training a neural network-based
    Parametric t-SNE dimensionality reduction model.

    Parameters
    ----------
    `tsne` : `ParametricTSNE`
        Parametric t-SNE model for feature extraction. Its `batch_size`,
        `model`, `loss_fn`, `early_exaggeration_epochs`,
        `early_exaggeration_value` and `device` attributes, and its
        `_calculate_P` method, are used by this module.
    `shuffle` : `bool`
        Flag indicating whether to shuffle the samples inside each training
        batch (inputs and the matching P block are permuted together).
    `optimizer` : `str`, optional
        Optimizer for training: one of `adam`, `sgd` or `rmsprop`.
        Defaults to `adam`.
    `lr` : `float`, optional
        Learning rate for the optimizer. Defaults to `1e-3`.

    Note
    ----
    This class defines a Lightning Module for training a neural network-based
    Parametric t-SNE dimensionality reduction model for feature extraction.
    It includes methods for the training step, configuring optimizers, and
    handling the training process, including the early-exaggeration phase.
    """

    def __init__(
        self,
        tsne: "ParametricTSNE",
        shuffle: bool,
        optimizer: str = "adam",
        lr: float = 1e-3,
    ):
        super().__init__()
        self.tsne = tsne
        self.batch_size = tsne.batch_size
        self.model = self.tsne.model
        self.loss_fn = tsne.loss_fn
        self.exaggeration_epochs = tsne.early_exaggeration_epochs
        self.exaggeration_value = tsne.early_exaggeration_value
        self.shuffle = shuffle
        self.lr = lr
        self.optimizer = optimizer
        self.reset_exaggeration_status()

    def reset_exaggeration_status(self):
        """
        Reset exaggeration status based on the number of exaggeration epochs.

        Note
        ----
        With zero exaggeration epochs the exaggeration phase is considered
        already finished.
        """
        self.has_exaggeration_ended = self.exaggeration_epochs == 0

    def _slice_P(self, P: torch.Tensor, batch_idx: int) -> torch.Tensor:
        """
        Return the rows of `P` that belong to batch number `batch_idx`.

        Parameters
        ----------
        `P` : `torch.Tensor`
            P matrix produced by `tsne._calculate_P`.
        `batch_idx` : `int`
            Index of the current batch.

        Returns
        -------
        `torch.Tensor`
            Rows `[batch_idx * batch_size, (batch_idx + 1) * batch_size)` of `P`.

        Note
        ----
        Presumably `_calculate_P` stacks one block per batch along the rows,
        in dataloader order — TODO confirm against `ParametricTSNE`.
        """
        start = batch_idx * self.batch_size
        return P[start : start + self.batch_size]

    def training_step(
        self,
        batch: Union[
            torch.Tensor, Tuple[torch.Tensor, ...], List[Union[torch.Tensor, Any]]
        ],
        batch_idx: int,
    ):
        """
        Perform a single training step.

        Parameters
        ----------
        `batch` : `Union[torch.Tensor, Tuple[torch.Tensor, ...], List[Union[torch.Tensor, Any]]]`
            Input batch; its first element is used as the model input.
        `batch_idx` : `int`
            Index of the current batch.

        Returns
        -------
        `Dict[str, torch.Tensor]`
            Dictionary containing the `loss` value.

        Note
        ----
        This method defines a single training step for the dimensionality
        reduction model. It computes the loss from the model's `logits` and the
        block `_P_batch` of the current (possibly exaggerated) P matrix.
        """
        x = batch[0]
        # The P block is selected by batch position, so it only lines up with
        # `x` if the dataloader yields samples in the order P was computed in.
        _P_batch = self._slice_P(self.P_current, batch_idx)

        if self.shuffle:
            # Apply one permutation to the samples and to both axes of the P
            # block so that P stays aligned with the reordered inputs.
            p_idxs = torch.randperm(x.shape[0])
            x = x[p_idxs]
            _P_batch = _P_batch[p_idxs, :]
            _P_batch = _P_batch[:, p_idxs]

        logits = self.model(x)
        loss = self.loss_fn(
            logits,
            _P_batch,
            {"device": self.tsne.device, "batch_size": self.batch_size},
        )
        self.log(
            "train_loss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True
        )
        return {"loss": loss}

    def validation_step(
        self,
        batch: Union[
            torch.Tensor, Tuple[torch.Tensor, ...], List[Union[torch.Tensor, Any]]
        ],
        batch_idx: int,
        dataloader_idx: Union[int, None] = None,
    ):
        """
        Perform a single validation step.

        Parameters
        ----------
        `batch` : `Union[torch.Tensor, Tuple[torch.Tensor, ...], List[Union[torch.Tensor, Any]]]`
            Input batch; its first element is used as the model input.
        `batch_idx` : `int`
            Index of the current batch.
        `dataloader_idx` : `int`, optional
            Index of the validation dataloader. `None` selects the first one.

        Returns
        -------
        `Dict[str, torch.Tensor]`
            Dictionary containing the `loss` value.

        Note
        ----
        This method defines a single validation step for the dimensionality
        reduction model. It computes the loss from the model's `logits` and the
        block `_P_batch` of the P matrix computed for the given validation
        dataloader.
        """
        x = batch[0]
        # Lightning omits `dataloader_idx` when only one loader is in use.
        loader_idx = 0 if dataloader_idx is None else dataloader_idx
        _P_batch = self._slice_P(self.val_P[loader_idx], batch_idx)

        logits = self.model(x)
        loss = self.loss_fn(
            logits,
            _P_batch,
            {"device": self.tsne.device, "batch_size": self.batch_size},
        )
        self.log(
            "val_loss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True
        )
        return {"loss": loss}

    def _set_optimizer(
        self, optimizer: str, optimizer_params: dict
    ) -> torch.optim.Optimizer:
        """
        Set the optimizer based on the provided string.

        Parameters
        ----------
        `optimizer` : `str`
            Name of the desired optimizer: `adam`, `sgd` or `rmsprop`.
        `optimizer_params` : `dict`
            Keyword arguments forwarded to the optimizer constructor.

        Returns
        -------
        `torch.optim.Optimizer`
            Optimizer over the parameters of `self.model`.

        Raises
        ------
        `ValueError`
            If `optimizer` is not one of the supported names.
        """
        if optimizer == "adam":
            return optim.Adam(self.model.parameters(), **optimizer_params)
        elif optimizer == "sgd":
            return optim.SGD(self.model.parameters(), **optimizer_params)
        elif optimizer == "rmsprop":
            return optim.RMSprop(self.model.parameters(), **optimizer_params)
        else:
            raise ValueError("Unknown optimizer")

    def configure_optimizers(self) -> torch.optim.Optimizer:
        """
        Configure the optimizer for training.

        Returns
        -------
        `torch.optim.Optimizer`
            Optimizer selected by `self.optimizer`, using learning rate `self.lr`.
        """
        return self._set_optimizer(self.optimizer, {"lr": self.lr})

    def on_train_start(self) -> None:
        """
        Perform actions at the beginning of the training process.

        Note
        ----
        Computes the P matrix from the training dataloader. It is computed only
        once and reused if it already exists on the module.
        """
        if not hasattr(self, "P"):
            self.P = self.tsne._calculate_P(self.trainer.train_dataloader)

    def on_train_epoch_start(self) -> None:
        """
        Perform actions at the start of each training epoch.

        Note
        ----
        Selects the P matrix used by `training_step`:
        - during the first `exaggeration_epochs` epochs, a copy of `P`
          multiplied by `exaggeration_value`;
        - afterwards, the original `P`.
        """
        # Once exaggeration is over `P_current` already points at `P`.
        # Epoch 0 always falls through so that `P_current` gets initialized.
        if self.current_epoch > 0 and self.has_exaggeration_ended:
            return
        if (
            self.exaggeration_epochs > 0
            and self.current_epoch < self.exaggeration_epochs
        ):
            if not hasattr(self, "P_multiplied"):
                self.P_multiplied = self.P.clone()
                self.P_multiplied *= self.exaggeration_value
            self.P_current = self.P_multiplied
            # Still inside the exaggeration phase. Clear a flag possibly left
            # set by an earlier fit; otherwise the early return above would
            # keep the exaggerated matrix for the rest of this run.
            self.has_exaggeration_ended = False
        else:
            self.P_current = self.P
            self.has_exaggeration_ended = True

    def on_train_epoch_end(self) -> None:
        """
        Perform actions at the end of each training epoch.

        Note
        ----
        Once exaggeration has ended, drops this module's `P_multiplied`
        attribute so the exaggerated copy is no longer referenced.
        """
        if hasattr(self, "P_multiplied") and self.has_exaggeration_ended:
            del self.P_multiplied

    def on_validation_start(self) -> None:
        """
        Perform actions at the beginning of the validation process.

        Note
        ----
        Computes one P matrix per validation dataloader. They are computed only
        once and reused if they already exist on the module. A single
        dataloader is treated as a one-element list.
        """
        if not hasattr(self, "val_P"):
            loaders = self.trainer.val_dataloaders
            # The trainer may expose a single validation loader directly rather
            # than inside a list. Iterating it would walk over its batches
            # instead of over loaders.
            if not isinstance(loaders, (list, tuple)):
                loaders = [loaders]
            self.val_P = [self.tsne._calculate_P(loader) for loader in loaders]

    def predict_step(self, batch, batch_idx, dataloader_idx=None):
        """
        Perform a single step during the prediction process.

        Parameters
        ----------
        `batch`
            Input batch; its first element is used as the model input.
        `batch_idx`
            Index of the current batch (unused).
        `dataloader_idx` : optional
            Index of the dataloader (unused).

        Returns
        -------
        `torch.Tensor`
            Output of `self.model` for the batch inputs.
        """
        return self.model(batch[0])