Coverage for NeuralTSNE/NeuralTSNE/TSNE/ParametricTSNE/parametric_tsne.py: 86%
83 statements
« prev ^ index » next coverage.py v7.8.0, created at 2025-05-18 16:32 +0000
« prev ^ index » next coverage.py v7.8.0, created at 2025-05-18 16:32 +0000
1from collections import OrderedDict
2from typing import Callable, List, Tuple, Union
4import torch
5import torchinfo
6from torch import nn
7from torch.utils.data import DataLoader, Dataset, TensorDataset, random_split
9from tqdm import tqdm
11from NeuralTSNE.TSNE.Helpers import x2p
12from NeuralTSNE.TSNE.CostFunctions import CostFunctions
13from NeuralTSNE.TSNE.NeuralNetwork import NeuralNetwork, BaseModel
15from NeuralTSNE.Utils import does_sum_up_to
class ParametricTSNE:
    """
    Parametric t-SNE implementation using a neural network model.

    Parameters
    ----------
    `loss_fn` : `str`
        Loss function for t-SNE. Currently supports `kl_divergence`.
    `perplexity` : `int`
        Perplexity parameter for t-SNE.
    `batch_size` : `int`
        Batch size for training.
    `early_exaggeration_epochs` : `int`
        Number of epochs for early exaggeration.
    `early_exaggeration_value` : `float`
        Early exaggeration factor.
    `max_iterations` : `int`
        Maximum number of iterations for optimization.
    `n_components` : `int`, optional
        Number of components in the output. Defaults to `None`.
    `features` : `int`, optional
        Number of input features. Defaults to `None`.
    `multipliers` : `List[float]`, optional
        List of multipliers for hidden layers in the neural network. Defaults to `None`.
    `n_jobs` : `int`, optional
        Number of workers for data loading. Defaults to `0`.
    `tolerance` : `float`, optional
        Tolerance level for convergence. Defaults to `1e-5`.
    `force_cpu` : `bool`, optional
        Force using CPU even if GPU is available. Defaults to `False`.
    `model` : `Union[NeuralNetwork, BaseModel, nn.Sequential, OrderedDict]`, optional
        Predefined model. `NeuralNetwork`/`BaseModel` instances are used as-is;
        an `OrderedDict` or `nn.Sequential` is wrapped in a `NeuralNetwork`.
        Defaults to `None`.

    Raises
    ------
    `AttributeError`
        If neither a model nor all of `features`, `n_components` and
        `multipliers` are provided, or if `model` is of an unsupported type.
    """

    def __init__(
        self,
        loss_fn: str,
        perplexity: int,
        batch_size: int,
        early_exaggeration_epochs: int,
        early_exaggeration_value: float,
        max_iterations: int,
        n_components: Union[int, None] = None,
        features: Union[int, None] = None,
        multipliers: Union[List[float], None] = None,
        n_jobs: int = 0,
        tolerance: float = 1e-5,
        force_cpu: bool = False,
        model: Union["NeuralNetwork", nn.Module, OrderedDict, None] = None,
    ):
        if model is None and (
            features is None or n_components is None or multipliers is None
        ):
            raise AttributeError(
                "Either a model or features, n_components, and multipliers must be provided."
            )

        # Use the first CUDA device unless the caller forces CPU or none exists.
        if force_cpu or not torch.cuda.is_available():
            self.device = torch.device("cpu")
        else:
            self.device = torch.device("cuda:0")

        if model is None:
            self.model = NeuralNetwork(features, n_components, multipliers).to(
                self.device
            )
        elif isinstance(model, (NeuralNetwork, BaseModel)):
            self.model = model.to(self.device)
        elif isinstance(model, (OrderedDict, nn.Sequential)):
            self.model = NeuralNetwork(pre_filled_layers=model).to(self.device)
        else:
            # Previously this case left the model unset and failed below with an
            # opaque "'NoneType' object has no attribute 'in_features'".
            # AttributeError is kept so existing handlers still catch it and
            # to match the argument validation at the top of the constructor.
            raise AttributeError(
                f"Unsupported model type: {type(model).__name__}. Expected a "
                "NeuralNetwork, BaseModel, nn.Sequential or OrderedDict."
            )

        # The model is the source of truth for the input width, including
        # when a predefined model was supplied and `features` was omitted.
        features = self.model.in_features

        torchinfo.summary(
            self.model,
            input_size=(batch_size, 1, features),
            col_names=(
                "input_size",
                "output_size",
                "num_params",
                "kernel_size",
                "mult_adds",
            ),
        )

        self.perplexity = perplexity
        self.batch_size = batch_size
        self.early_exaggeration_epochs = early_exaggeration_epochs
        self.early_exaggeration_value = early_exaggeration_value
        self.n_jobs = n_jobs
        self.tolerance = tolerance
        self.max_iterations = max_iterations

        self.loss_fn = self.set_loss_fn(loss_fn)

    def set_loss_fn(self, loss_fn: str) -> Callable:
        """
        Set the loss function based on the provided string.

        Parameters
        ----------
        `loss_fn` : `str`
            String indicating the desired loss function.

        Returns
        -------
        `Callable`
            Corresponding loss function.

        Note
        ----
        Currently supports `kl_divergence` as the loss function.
        The resolved function is also stored on `self.loss_fn`.
        """
        fn = CostFunctions(loss_fn)
        self.loss_fn = fn
        return fn

    def save_model(self, filename: str):
        """
        Save the model's state dictionary to a file.

        Parameters
        ----------
        `filename` : `str`
            Name of the file to save the model.
        """
        torch.save(self.model.state_dict(), filename)

    def read_model(self, filename: str):
        """
        Load the model's state dictionary from a file.

        Parameters
        ----------
        `filename` : `str`
            Name of the file to load the model.

        Note
        ----
        Tensors are mapped onto `self.device`, so a state dictionary saved
        on a GPU can be loaded on a CPU-only host (and vice versa).
        """
        # NOTE: torch.load unpickles the file; only load checkpoints that come
        # from a trusted source.
        state_dict = torch.load(filename, map_location=self.device)
        self.model.load_state_dict(state_dict)

    def split_dataset(
        self,
        X: torch.Tensor,
        y: Union[torch.Tensor, None] = None,
        train_size: Union[float, None] = None,
        test_size: Union[float, None] = None,
    ) -> Tuple[Union[DataLoader, None], Union[DataLoader, None]]:
        """
        Split the dataset into training and testing sets.

        Parameters
        ----------
        `X` : `torch.Tensor`
            Input data tensor.
        `y` : `torch.Tensor`, optional
            Target tensor. Default is `None`.
        `train_size` : `float`, optional
            Proportion of the dataset to include in the training set.
        `test_size` : `float`, optional
            Proportion of the dataset to include in the testing set.

        Returns
        -------
        `Tuple[DataLoader | None, DataLoader | None]`
            Tuple containing training and testing dataloaders. A loader is
            `None` when its split would contain no samples.

        Note
        ----
        Splits the input data into training and testing sets, and returns corresponding dataloaders.
        The training sample count is the proportion times the dataset length,
        truncated to an integer; the testing set receives all remaining samples.
        """
        train_size, test_size = self._determine_train_test_split(train_size, test_size)
        if y is None:
            dataset = TensorDataset(X)
        else:
            dataset = TensorDataset(X, y)

        # Convert the proportion into sample counts; the test split takes the
        # remainder so both counts always add up to the dataset length.
        train_size = int(train_size * len(dataset))
        test_size = len(dataset) - train_size
        train_dataset, test_dataset = random_split(dataset, [train_size, test_size])

        # An empty split is reported as "no loader" instead of an empty one.
        if train_size == 0:
            train_dataset = None
        if test_size == 0:
            test_dataset = None

        return self.create_dataloaders(train_dataset, test_dataset)

    def _determine_train_test_split(
        self, train_size: Union[float, None], test_size: Union[float, None]
    ) -> Tuple[float, float]:
        """
        Determine the proportions of training and testing sets.

        Parameters
        ----------
        `train_size` : `float`
            Proportion of the dataset to include in the training set.
        `test_size` : `float`
            Proportion of the dataset to include in the testing set.

        Returns
        -------
        `Tuple[float, float]`
            Tuple containing the determined proportions.

        Note
        ----
        With neither value given the split defaults to 0.8 for training.
        A single missing value is derived as the complement of the other.
        When both are given but do not sum to 1, `train_size` wins and
        `test_size` is recomputed from it.
        """
        if train_size is None and test_size is None:
            train_size = 0.8
            test_size = 1 - train_size
        elif train_size is None:
            train_size = 1 - test_size
        elif test_size is None:
            test_size = 1 - train_size
        elif not does_sum_up_to(train_size, test_size, 1):
            test_size = 1 - train_size
        return train_size, test_size

    def create_dataloaders(
        self, train: Union[Dataset, None], test: Union[Dataset, None]
    ) -> Tuple[Union[DataLoader, None], Union[DataLoader, None]]:
        """
        Create dataloaders for training and testing sets.

        Parameters
        ----------
        `train` : `Dataset`
            Training dataset, or `None` to skip the training loader.
        `test` : `Dataset`
            Testing dataset, or `None` to skip the testing loader.

        Returns
        -------
        `Tuple[DataLoader | None, DataLoader | None]`
            Tuple containing training and testing dataloaders.

        Note
        ----
        On CPU the loaders use `n_jobs` workers and no pinned memory; on any
        other device they use pinned memory and no worker processes. The
        training loader drops an incomplete final batch, the testing loader
        keeps it.
        """
        # Compare the device *type*: a `torch.device` never compares equal to
        # the plain string "cpu", which previously forced the non-CPU settings
        # (pinned memory, zero workers) even on CPU-only runs.
        on_cpu = torch.device(self.device).type == "cpu"

        def build_loader(dataset, drop_last):
            # Build one loader, or pass through the "no dataset" marker.
            if dataset is None:
                return None
            return DataLoader(
                dataset,
                batch_size=self.batch_size,
                drop_last=drop_last,
                pin_memory=not on_cpu,
                num_workers=self.n_jobs if on_cpu else 0,
            )

        train_loader = build_loader(train, True)
        test_loader = build_loader(test, False)
        return train_loader, test_loader

    def _calculate_P(self, dataloader: DataLoader) -> torch.Tensor:
        """
        Calculate joint probability matrix P.

        Parameters
        ----------
        `dataloader` : `DataLoader`
            Dataloader for the dataset.

        Returns
        -------
        `torch.Tensor`
            Joint probability matrix P with one row per dataset sample and
            `batch_size` columns. Rows of batch `i` occupy
            `[i * batch_size, (i + 1) * batch_size)`; rows never produced by
            the loader stay zero.

        Note
        ----
        Each batch is processed independently: NaNs are zeroed, the matrix is
        symmetrized, normalized to sum to 1 and floored at `1e-12`.
        """
        n = len(dataloader.dataset)
        P = torch.zeros((n, self.batch_size), device=self.device)
        # Lower bound for every probability; built once instead of per batch.
        floor = torch.tensor([1e-12], device=self.device)
        for i, (X, *_) in tqdm(
            enumerate(dataloader),
            unit="batch",
            total=len(dataloader),
            desc="Calculating P",
            leave=True,
            position=0,
        ):
            batch = x2p(X, self.perplexity, self.tolerance)
            batch[torch.isnan(batch)] = 0
            batch = batch + batch.mT
            batch = batch / batch.sum()
            batch = torch.maximum(batch.to(self.device), floor)
            # NOTE(review): presumably x2p returns a square matrix per batch;
            # if so, a final batch shorter than batch_size would not fit this
            # slice - confirm how loaders without drop_last are handled.
            P[i * self.batch_size : (i + 1) * self.batch_size] = batch
        return P