Coverage for NeuralTSNE/NeuralTSNE/TSNE/tests/test_dimensionality_reduction.py: 100%
169 statements
coverage.py v7.8.0, created at 2025-05-18 16:32 +0000
from unittest.mock import patch

import pytest
import pytorch_lightning as L
import torch

from NeuralTSNE.TSNE.Modules import DimensionalityReduction
from NeuralTSNE.TSNE.tests.common import (
    MyDataset,
    DataLoaderMock,
)
from NeuralTSNE.TSNE.tests.fixtures.parametric_tsne_fixtures import (
    default_parametric_tsne_instance,
)
from NeuralTSNE.TSNE.ParametricTSNE import ParametricTSNE
from NeuralTSNE.TSNE.tests.fixtures.dimensionality_reduction_fixtures import (
    default_classifier_instance,
    classifier_instance,
)
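
# Constructor test: the wrapper should mirror the configuration of the
# ParametricTSNE instance it is built around and reset the exaggeration
# status exactly once during init.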
@pytest.mark.parametrize(
    "classifier_instance",
    [{"shuffle": False, "optimizer": "adam", "lr": 1e-5}],
    indirect=True,
)
def test_classifier_init(classifier_instance):
    classifier_instance, params, mock_exaggeration_status = classifier_instance

    assert isinstance(classifier_instance, DimensionalityReduction)
    assert classifier_instance.tsne == params["tsne"]
    assert classifier_instance.batch_size == params["tsne"].batch_size
    assert classifier_instance.model == params["tsne"].model
    assert classifier_instance.loss_fn == params["tsne"].loss_fn
    assert (
        classifier_instance.exaggeration_epochs
        == params["tsne"].early_exaggeration_epochs
    )
    assert (
        classifier_instance.exaggeration_value
        == params["tsne"].early_exaggeration_value
    )
    assert classifier_instance.shuffle == params["shuffle"]
    assert classifier_instance.lr == params["lr"]
    assert classifier_instance.optimizer == params["optimizer"]
    assert mock_exaggeration_status.call_count == 1
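
# reset_exaggeration_status: with zero early-exaggeration epochs there is
# nothing to exaggerate, so the flag should read as already ended.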
@pytest.mark.parametrize(
    "default_classifier_instance",
    [{"early_exaggeration_epochs": 0}, {"early_exaggeration_epochs": 10}],
    indirect=True,
)
def test_reset_exaggeration_status(default_classifier_instance):
    classifier_instance, params = default_classifier_instance
    classifier_instance.reset_exaggeration_status()

    params = params["tsne_params"]
    if params["early_exaggeration_epochs"] == 0:
        assert classifier_instance.has_exaggeration_ended is True
    else:
        assert classifier_instance.has_exaggeration_ended is False
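
# _set_optimizer maps a lowercase optimizer name to the corresponding
# torch.optim class and forwards the hyperparameters (here just the
# learning rate).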
@pytest.mark.parametrize(
    "optimizer, expected_instance",
    [
        ("adam", torch.optim.Adam),
        ("sgd", torch.optim.SGD),
        ("rmsprop", torch.optim.RMSprop),
    ],
)
def test_set_optimizer(
    default_classifier_instance,
    optimizer: str,
    expected_instance: type[torch.optim.Optimizer],
):
    classifier_instance, _ = default_classifier_instance

    returned = classifier_instance._set_optimizer(
        optimizer, {"lr": classifier_instance.lr}
    )
    assert isinstance(returned, expected_instance)
    assert returned.param_groups[0]["lr"] == classifier_instance.lr
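
# Unknown optimizer names (including near-misses such as "adom") should
# raise ValueError rather than silently fall back to a default.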
@pytest.mark.parametrize("optimizer", ["dummy_optimizer", "adom"])
def test_set_optimizer_invalid(default_classifier_instance, optimizer: str):
    classifier_instance, _ = default_classifier_instance

    with pytest.raises(ValueError):
        classifier_instance._set_optimizer(optimizer, {"lr": classifier_instance.lr})
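
# predict_step should produce one embedding per sample, with dimensionality
# equal to the configured number of t-SNE components.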
def test_predict_step(default_classifier_instance):
    classifier_instance, params = default_classifier_instance
    tsne_instance = classifier_instance.tsne
    num_samples = tsne_instance.batch_size * 10
    dataset = MyDataset(num_samples, 15)
    test_data = DataLoaderMock(dataset, batch_size=tsne_instance.batch_size)

    for i, batch in enumerate(test_data):
        logits = classifier_instance.predict_step(batch, i)
        assert logits.shape == (
            tsne_instance.batch_size,
            params["default_tsne_params"]["n_components"],
        )
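
# on_train_epoch_end appears to drop the temporary P_multiplied matrix once
# early exaggeration has ended; otherwise the matrix is left in place.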
@pytest.mark.parametrize("has_P_multiplied", [True, False])
@pytest.mark.parametrize("has_exaggeration_ended", [True, False])
def test_on_train_epoch_end(
    default_classifier_instance, has_P_multiplied: bool, has_exaggeration_ended: bool
):
    classifier_instance, _ = default_classifier_instance

    if has_P_multiplied:
        classifier_instance.P_multiplied = torch.tensor(torch.nan)
    classifier_instance.has_exaggeration_ended = has_exaggeration_ended

    classifier_instance.on_train_epoch_end()

    if has_P_multiplied:
        assert (
            hasattr(classifier_instance, "P_multiplied") is not has_exaggeration_ended
        )
    else:
        assert hasattr(classifier_instance, "P_multiplied") is False
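
# on_train_start should compute the joint probability matrix P via
# ParametricTSNE._calculate_P only when P is not already present.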
@pytest.mark.parametrize("has_P", [True, False])
def test_on_train_start(default_classifier_instance, has_P: bool):
    classifier_instance, _ = default_classifier_instance
    tsne_instance = classifier_instance.tsne
    num_samples = tsne_instance.batch_size * 10
    dataset = MyDataset(num_samples, 15)
    test_data = DataLoaderMock(dataset, batch_size=tsne_instance.batch_size)

    trainer = L.Trainer(fast_dev_run=True)

    if has_P:
        classifier_instance.P = torch.tensor(torch.nan)

    with (
        patch.object(ParametricTSNE, "_calculate_P") as mocked_calculate_P,
        patch.object(
            DimensionalityReduction, "training_step", autospec=True
        ) as mocked_training_step,
        patch.object(DimensionalityReduction, "on_train_epoch_start"),
        patch.object(DimensionalityReduction, "on_train_epoch_end"),
    ):
        mocked_calculate_P.return_value = torch.tensor(torch.nan)
        mocked_training_step.return_value = None

        trainer.fit(classifier_instance, test_data)

    if not has_P:
        assert mocked_calculate_P.call_count == 1
    else:
        assert mocked_calculate_P.call_count == 0

    assert torch.allclose(
        classifier_instance.P, torch.tensor(torch.nan), equal_nan=True
    )
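
# on_train_epoch_start governs early exaggeration. Judging by the assertions
# below, while exaggeration is pending P_current is P scaled by
# early_exaggeration_value; once the exaggeration window has passed, P is
# used unscaled and has_exaggeration_ended is flipped to True.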
@pytest.mark.parametrize("epochs", [1, 2, 3])
@pytest.mark.parametrize("has_exaggeration_ended", [True, False])
@pytest.mark.parametrize("exaggeration_epochs", [0, 1])
def test_on_train_epoch_start(
    default_classifier_instance,
    epochs: int,
    has_exaggeration_ended: bool,
    exaggeration_epochs: int,
):
    classifier_instance, params = default_classifier_instance

    tsne_instance = classifier_instance.tsne
    num_samples = tsne_instance.batch_size * 10
    dataset = MyDataset(num_samples, 15)
    test_data = DataLoaderMock(dataset, batch_size=tsne_instance.batch_size)

    trainer = L.Trainer(max_epochs=epochs, limit_train_batches=1)

    input_P = torch.ones((num_samples, tsne_instance.batch_size))
    classifier_instance.P = input_P

    classifier_instance.has_exaggeration_ended = has_exaggeration_ended
    classifier_instance.exaggeration_epochs = exaggeration_epochs

    with (
        patch.object(DimensionalityReduction, "on_train_start"),
        patch.object(
            DimensionalityReduction, "training_step", autospec=True
        ) as mocked_training_step,
        patch.object(DimensionalityReduction, "on_train_epoch_end"),
    ):
        mocked_training_step.return_value = None

        trainer.fit(classifier_instance, test_data)

    if has_exaggeration_ended and exaggeration_epochs == 0:
        assert torch.allclose(classifier_instance.P_current, input_P)
    elif has_exaggeration_ended:
        assert torch.allclose(
            classifier_instance.P_current,
            input_P * params["default_tsne_params"]["early_exaggeration_value"],
        )

    if (
        not has_exaggeration_ended
        and epochs <= exaggeration_epochs
        and exaggeration_epochs > 0
    ):
        assert torch.allclose(
            classifier_instance.P_current,
            input_P * params["default_tsne_params"]["early_exaggeration_value"],
        )
    elif not has_exaggeration_ended:
        assert torch.allclose(classifier_instance.P_current, input_P)
        assert classifier_instance.has_exaggeration_ended is True
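
# Smoke test: a single fast_dev_run training step should complete without
# error when P is provided up front (on_train_start is patched out).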
def test_training_step(default_classifier_instance):
    classifier_instance, params = default_classifier_instance

    tsne_instance = classifier_instance.tsne
    num_samples = tsne_instance.batch_size * 10
    dataset = MyDataset(num_samples, 15)
    test_data = DataLoaderMock(dataset, batch_size=tsne_instance.batch_size)

    trainer = L.Trainer(fast_dev_run=True, accelerator="cpu")

    input_P = torch.ones((num_samples, tsne_instance.batch_size))
    classifier_instance.P = input_P

    with patch.object(DimensionalityReduction, "on_train_start"):
        trainer.fit(classifier_instance, test_data)
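
# on_validation_start mirrors on_train_start for each validation dataloader:
# val_P entries should be computed only when they do not already exist.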
@pytest.mark.parametrize("validation_dataloaders_count", [1, 2, 3])
@pytest.mark.parametrize("has_val_P", [True, False])
def test_on_validation_start(
    default_classifier_instance, has_val_P: bool, validation_dataloaders_count: int
):
    # TODO: Maybe another way to test this? Try to skip the training step if
    # possible. Maybe switch to non-zero tensors as well.
    classifier_instance, _ = default_classifier_instance
    tsne_instance = classifier_instance.tsne
    num_samples = tsne_instance.batch_size * 10
    dataset = MyDataset(num_samples, 15)
    test_data = DataLoaderMock(dataset, batch_size=tsne_instance.batch_size)
    test_val_data = [
        DataLoaderMock(dataset, batch_size=tsne_instance.batch_size)
        for _ in range(validation_dataloaders_count)
    ]

    trainer = L.Trainer(fast_dev_run=True, limit_train_batches=0)

    if has_val_P:
        classifier_instance.val_P = [
            torch.tensor(torch.nan) for _ in range(validation_dataloaders_count)
        ]

    classifier_instance.P = torch.tensor(torch.nan)

    with (
        patch.object(ParametricTSNE, "_calculate_P") as mocked_calculate_P,
        patch.object(
            DimensionalityReduction, "validation_step", autospec=True
        ) as mocked_validation_step,
        patch.object(
            DimensionalityReduction, "training_step", autospec=True
        ) as mocked_training_step,
        patch.object(DimensionalityReduction, "on_train_epoch_start"),
        patch.object(DimensionalityReduction, "on_train_epoch_end"),
    ):
        mocked_calculate_P.return_value = torch.tensor(torch.nan)
        mocked_validation_step.return_value = None
        mocked_training_step.return_value = None

        trainer.fit(classifier_instance, test_data, test_val_data)

    if not has_val_P:
        assert mocked_calculate_P.call_count == validation_dataloaders_count
    else:
        assert mocked_calculate_P.call_count == 0

    expected_val_P = [
        torch.tensor(torch.nan) for _ in range(validation_dataloaders_count)
    ]
    for i in range(validation_dataloaders_count):
        assert torch.allclose(
            classifier_instance.val_P[i], expected_val_P[i], equal_nan=True
        )
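
# Smoke test: validation_step should run across multiple validation
# dataloaders when P and val_P are pre-populated.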
@pytest.mark.parametrize("validation_dataloaders_count", [1, 2, 3])
def test_validation_step(
    default_classifier_instance, validation_dataloaders_count: int
):
    # TODO: Check in actual training
    classifier_instance, params = default_classifier_instance

    tsne_instance = classifier_instance.tsne
    num_samples = tsne_instance.batch_size * 10
    dataset = MyDataset(num_samples, 15)
    test_data = DataLoaderMock(dataset, batch_size=tsne_instance.batch_size)
    test_val_data = [
        DataLoaderMock(dataset, batch_size=tsne_instance.batch_size)
        for _ in range(validation_dataloaders_count)
    ]
    trainer = L.Trainer(fast_dev_run=True, accelerator="cpu")

    input_P = torch.ones((num_samples, tsne_instance.batch_size))
    input_val_P = [
        torch.ones((num_samples, tsne_instance.batch_size))
        for _ in range(validation_dataloaders_count)
    ]
    classifier_instance.P = input_P
    classifier_instance.val_P = input_val_P

    with patch.object(DimensionalityReduction, "on_validation_start"):
        trainer.fit(classifier_instance, test_data, test_val_data)
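
# configure_optimizers should defer to the stored optimizer name and learning
# rate, returning the matching torch.optim instance.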
@pytest.mark.parametrize(
    "optimizer, expected_instance",
    [
        ("adam", torch.optim.Adam),
        ("sgd", torch.optim.SGD),
        ("rmsprop", torch.optim.RMSprop),
    ],
)
def test_configure_optimizers(
    default_classifier_instance,
    optimizer: str,
    expected_instance: type[torch.optim.Optimizer],
):
    classifier_instance, _ = default_classifier_instance
    classifier_instance.optimizer = optimizer

    returned = classifier_instance.configure_optimizers()
    assert isinstance(returned, expected_instance)
    assert returned.param_groups[0]["lr"] == classifier_instance.lr
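
# As with _set_optimizer, unrecognized optimizer names should surface as
# ValueError from configure_optimizers.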
@pytest.mark.parametrize("optimizer", ["dummy_optimizer", "adom"])
def test_configure_optimizers_invalid(default_classifier_instance, optimizer: str):
    classifier_instance, _ = default_classifier_instance
    classifier_instance.optimizer = optimizer

    with pytest.raises(ValueError):
        classifier_instance.configure_optimizers()