Coverage for NeuralTSNE/NeuralTSNE/TSNE/tests/test_dimensionality_reduction.py: 100%

169 statements  

coverage.py v7.8.0, created at 2025-05-18 16:32 +0000

from unittest.mock import patch

import pytest
import pytorch_lightning as L
import torch

from NeuralTSNE.TSNE.Modules import DimensionalityReduction
from NeuralTSNE.TSNE.tests.common import (
    MyDataset,
    DataLoaderMock,
)

from NeuralTSNE.TSNE.tests.fixtures.parametric_tsne_fixtures import (
    default_parametric_tsne_instance,
)

from NeuralTSNE.TSNE.ParametricTSNE import ParametricTSNE
from NeuralTSNE.TSNE.tests.fixtures.dimensionality_reduction_fixtures import (
    default_classifier_instance,
    classifier_instance,
)
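# `MyDataset` and `DataLoaderMock` are shared test doubles; the tests below only
# rely on a small surface. A rough sketch of that assumed interface (the real
# helpers in `tests/common.py` may differ):
#
#     class MyDataset(torch.utils.data.Dataset):
#         def __init__(self, num_samples: int, num_features: int):
#             self.data = torch.randn(num_samples, num_features)
#
#         def __len__(self):
#             return self.data.shape[0]
#
#         def __getitem__(self, idx):
#             return self.data[idx]
#
# `DataLoaderMock(dataset, batch_size=...)` is assumed to iterate over batches
# of shape `(batch_size, num_features)`, mimicking `torch.utils.data.DataLoader`.
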

@pytest.mark.parametrize(
    "classifier_instance",
    [{"shuffle": False, "optimizer": "adam", "lr": 1e-5}],
    indirect=True,
)
def test_classifier_init(classifier_instance):
    classifier_instance, params, mock_exaggeration_status = classifier_instance

    assert isinstance(classifier_instance, DimensionalityReduction)
    assert classifier_instance.tsne == params["tsne"]
    assert classifier_instance.batch_size == params["tsne"].batch_size
    assert classifier_instance.model == params["tsne"].model
    assert classifier_instance.loss_fn == params["tsne"].loss_fn
    assert (
        classifier_instance.exaggeration_epochs
        == params["tsne"].early_exaggeration_epochs
    )
    assert (
        classifier_instance.exaggeration_value
        == params["tsne"].early_exaggeration_value
    )
    assert classifier_instance.shuffle == params["shuffle"]
    assert classifier_instance.lr == params["lr"]
    assert classifier_instance.optimizer == params["optimizer"]
    assert mock_exaggeration_status.call_count == 1


@pytest.mark.parametrize(
    "default_classifier_instance",
    [{"early_exaggeration_epochs": 0}, {"early_exaggeration_epochs": 10}],
    indirect=True,
)
def test_reset_exaggeration_status(default_classifier_instance):
    classifier_instance, params = default_classifier_instance
    classifier_instance.reset_exaggeration_status()

    tsne_params = params["tsne_params"]
    if tsne_params["early_exaggeration_epochs"] == 0:
        assert classifier_instance.has_exaggeration_ended is True
    else:
        assert classifier_instance.has_exaggeration_ended is False
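
# The two cases above fully pin down `reset_exaggeration_status`: with no
# exaggeration epochs configured, the phase is trivially over. A minimal
# sketch of the method, assuming it only inspects `exaggeration_epochs`:
#
#     def reset_exaggeration_status(self):
#         self.has_exaggeration_ended = self.exaggeration_epochs == 0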


@pytest.mark.parametrize(
    "optimizer, expected_instance",
    [
        ("adam", torch.optim.Adam),
        ("sgd", torch.optim.SGD),
        ("rmsprop", torch.optim.RMSprop),
    ],
)
def test_set_optimizer(
    default_classifier_instance,
    optimizer: str,
    expected_instance: torch.optim.Optimizer,
):
    classifier_instance, _ = default_classifier_instance

    returned = classifier_instance._set_optimizer(
        optimizer, {"lr": classifier_instance.lr}
    )
    assert isinstance(returned, expected_instance)
    assert returned.param_groups[0]["lr"] == classifier_instance.lr


@pytest.mark.parametrize("optimizer", ["dummy_optimizer", "adom"])
def test_set_optimizer_invalid(default_classifier_instance, optimizer: str):
    classifier_instance, _ = default_classifier_instance

    with pytest.raises(ValueError):
        classifier_instance._set_optimizer(optimizer, {"lr": classifier_instance.lr})
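
# Together, `test_set_optimizer` and `test_set_optimizer_invalid` pin down the
# `_set_optimizer` contract: a known name yields the matching `torch.optim`
# class built with the supplied kwargs, anything else raises `ValueError`.
# A minimal sketch consistent with that contract (an assumption, not the
# verbatim implementation):
#
#     def _set_optimizer(self, optimizer: str, optimizer_params: dict):
#         optimizers = {
#             "adam": torch.optim.Adam,
#             "sgd": torch.optim.SGD,
#             "rmsprop": torch.optim.RMSprop,
#         }
#         if optimizer not in optimizers:
#             raise ValueError(f"Unknown optimizer: {optimizer}")
#         return optimizers[optimizer](self.model.parameters(), **optimizer_params)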


def test_predict_step(default_classifier_instance):
    classifier_instance, params = default_classifier_instance
    tsne_instance = classifier_instance.tsne
    num_samples = tsne_instance.batch_size * 10
    dataset = MyDataset(num_samples, 15)
    test_data = DataLoaderMock(dataset, batch_size=tsne_instance.batch_size)

    for i, batch in enumerate(test_data):
        logits = classifier_instance.predict_step(batch, i)
        assert logits.shape == (
            tsne_instance.batch_size,
            params["default_tsne_params"]["n_components"],
        )
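
# `predict_step` is only constrained by its output shape here: one embedding of
# size `n_components` per sample in the batch. Modulo batch unpacking, it
# presumably reduces to a plain forward pass (an assumption):
#
#     def predict_step(self, batch, batch_idx):
#         return self.model(batch)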


@pytest.mark.parametrize("has_P_multiplied", [True, False])
@pytest.mark.parametrize("has_exaggeration_ended", [True, False])
def test_on_train_epoch_end(
    default_classifier_instance, has_P_multiplied: bool, has_exaggeration_ended: bool
):
    classifier_instance, _ = default_classifier_instance

    if has_P_multiplied:
        classifier_instance.P_multiplied = torch.tensor(torch.nan)
    classifier_instance.has_exaggeration_ended = has_exaggeration_ended

    classifier_instance.on_train_epoch_end()

    if has_P_multiplied:
        assert (
            hasattr(classifier_instance, "P_multiplied") is not has_exaggeration_ended
        )
    else:
        assert hasattr(classifier_instance, "P_multiplied") is False
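
# The assertions above imply that `on_train_epoch_end` drops the cached
# `P_multiplied` once early exaggeration has ended, and leaves it alone (or
# absent) otherwise. A hedged sketch:
#
#     def on_train_epoch_end(self):
#         if self.has_exaggeration_ended and hasattr(self, "P_multiplied"):
#             del self.P_multiplied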


@pytest.mark.parametrize("has_P", [True, False])
def test_on_train_start(default_classifier_instance, has_P: bool):
    classifier_instance, _ = default_classifier_instance
    tsne_instance = classifier_instance.tsne
    num_samples = tsne_instance.batch_size * 10
    dataset = MyDataset(num_samples, 15)
    test_data = DataLoaderMock(dataset, batch_size=tsne_instance.batch_size)

    trainer = L.Trainer(fast_dev_run=True)

    if has_P:
        classifier_instance.P = torch.tensor(torch.nan)

    with (
        patch.object(ParametricTSNE, "_calculate_P") as mocked_calculate_P,
        patch.object(
            DimensionalityReduction, "training_step", autospec=True
        ) as mocked_training_step,
        patch.object(DimensionalityReduction, "on_train_epoch_start"),
        patch.object(DimensionalityReduction, "on_train_epoch_end"),
    ):
        mocked_calculate_P.return_value = torch.tensor(torch.nan)
        mocked_training_step.return_value = None

        trainer.fit(classifier_instance, test_data)

        if not has_P:
            assert mocked_calculate_P.call_count == 1
        else:
            assert mocked_calculate_P.call_count == 0

        assert torch.allclose(
            classifier_instance.P, torch.tensor(torch.nan), equal_nan=True
        )
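
# From the call counts above, `on_train_start` computes the high-dimensional
# affinity matrix `P` lazily: `_calculate_P` runs exactly once when `P` is
# missing and not at all when it was provided. Roughly (an assumption):
#
#     def on_train_start(self):
#         if getattr(self, "P", None) is None:
#             self.P = self.tsne._calculate_P(self.trainer.train_dataloader)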


@pytest.mark.parametrize("epochs", [1, 2, 3])
@pytest.mark.parametrize("has_exaggeration_ended", [True, False])
@pytest.mark.parametrize("exaggeration_epochs", [0, 1])
def test_on_train_epoch_start(
    default_classifier_instance,
    epochs: int,
    has_exaggeration_ended: bool,
    exaggeration_epochs: int,
):
    classifier_instance, params = default_classifier_instance

    tsne_instance = classifier_instance.tsne
    num_samples = tsne_instance.batch_size * 10
    dataset = MyDataset(num_samples, 15)
    test_data = DataLoaderMock(dataset, batch_size=tsne_instance.batch_size)

    trainer = L.Trainer(max_epochs=epochs, limit_train_batches=1)

    input_P = torch.ones((num_samples, tsne_instance.batch_size))
    classifier_instance.P = input_P

    classifier_instance.has_exaggeration_ended = has_exaggeration_ended
    classifier_instance.exaggeration_epochs = exaggeration_epochs

    with (
        patch.object(DimensionalityReduction, "on_train_start"),
        patch.object(
            DimensionalityReduction, "training_step", autospec=True
        ) as mocked_training_step,
        patch.object(DimensionalityReduction, "on_train_epoch_end"),
    ):
        mocked_training_step.return_value = None

        trainer.fit(classifier_instance, test_data)

        if has_exaggeration_ended and exaggeration_epochs == 0:
            assert torch.allclose(classifier_instance.P_current, input_P)
        elif has_exaggeration_ended:
            assert torch.allclose(
                classifier_instance.P_current,
                input_P * params["default_tsne_params"]["early_exaggeration_value"],
            )

        if (
            not has_exaggeration_ended
            and epochs <= exaggeration_epochs
            and exaggeration_epochs > 0
        ):
            assert torch.allclose(
                classifier_instance.P_current,
                input_P * params["default_tsne_params"]["early_exaggeration_value"],
            )
        elif not has_exaggeration_ended:
            assert torch.allclose(classifier_instance.P_current, input_P)
            assert classifier_instance.has_exaggeration_ended is True
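
# Read together, the branches above encode the early-exaggeration schedule that
# `on_train_epoch_start` is expected to follow: with no exaggeration epochs
# configured, `P_current` stays plain `P`; while the phase is active,
# `P_current` is `P` scaled by `early_exaggeration_value`; and once training
# outlasts `exaggeration_epochs`, the module falls back to plain `P` and flips
# `has_exaggeration_ended` to True.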


def test_training_step(default_classifier_instance):
    classifier_instance, params = default_classifier_instance

    tsne_instance = classifier_instance.tsne
    num_samples = tsne_instance.batch_size * 10
    dataset = MyDataset(num_samples, 15)
    test_data = DataLoaderMock(dataset, batch_size=tsne_instance.batch_size)

    trainer = L.Trainer(fast_dev_run=True, accelerator="cpu")

    input_P = torch.ones((num_samples, tsne_instance.batch_size))
    classifier_instance.P = input_P

    with patch.object(DimensionalityReduction, "on_train_start"):
        trainer.fit(classifier_instance, test_data)


@pytest.mark.parametrize("validation_dataloaders_count", [1, 2, 3])
@pytest.mark.parametrize("has_val_P", [True, False])
def test_on_validation_start(
    default_classifier_instance, has_val_P: bool, validation_dataloaders_count: int
):
    # TODO: Find a better way to test this? Try to skip the training step if
    # possible, and maybe switch to non-zero tensors as well.
    classifier_instance, _ = default_classifier_instance
    tsne_instance = classifier_instance.tsne
    num_samples = tsne_instance.batch_size * 10
    dataset = MyDataset(num_samples, 15)
    test_data = DataLoaderMock(dataset, batch_size=tsne_instance.batch_size)
    test_val_data = [
        DataLoaderMock(dataset, batch_size=tsne_instance.batch_size)
        for _ in range(validation_dataloaders_count)
    ]

    trainer = L.Trainer(fast_dev_run=True, limit_train_batches=0)

    if has_val_P:
        classifier_instance.val_P = [
            torch.tensor(torch.nan) for _ in range(validation_dataloaders_count)
        ]

    classifier_instance.P = torch.tensor(torch.nan)

    with (
        patch.object(ParametricTSNE, "_calculate_P") as mocked_calculate_P,
        patch.object(
            DimensionalityReduction, "validation_step", autospec=True
        ) as mocked_validation_step,
        patch.object(
            DimensionalityReduction, "training_step", autospec=True
        ) as mocked_training_step,
        patch.object(DimensionalityReduction, "on_train_epoch_start"),
        patch.object(DimensionalityReduction, "on_train_epoch_end"),
    ):
        mocked_calculate_P.return_value = torch.tensor(torch.nan)
        mocked_validation_step.return_value = None
        mocked_training_step.return_value = None

        trainer.fit(classifier_instance, test_data, test_val_data)

        if not has_val_P:
            assert mocked_calculate_P.call_count == validation_dataloaders_count
        else:
            assert mocked_calculate_P.call_count == 0

        expected_val_P = [
            torch.tensor(torch.nan) for _ in range(validation_dataloaders_count)
        ]
        for i in range(validation_dataloaders_count):
            assert torch.allclose(
                classifier_instance.val_P[i], expected_val_P[i], equal_nan=True
            )
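
# Analogously to `on_train_start`, the call counts above suggest that
# `on_validation_start` builds one `val_P` per validation dataloader, and only
# when the list is missing. A rough sketch (the dataloader access is assumed):
#
#     def on_validation_start(self):
#         if getattr(self, "val_P", None) is None:
#             self.val_P = [
#                 self.tsne._calculate_P(loader)
#                 for loader in self.trainer.val_dataloaders
#             ]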


@pytest.mark.parametrize("validation_dataloaders_count", [1, 2, 3])
def test_validation_step(
    default_classifier_instance, validation_dataloaders_count: int
):
    # TODO: Check in actual training
    classifier_instance, params = default_classifier_instance

    tsne_instance = classifier_instance.tsne
    num_samples = tsne_instance.batch_size * 10
    dataset = MyDataset(num_samples, 15)
    test_data = DataLoaderMock(dataset, batch_size=tsne_instance.batch_size)
    test_val_data = [
        DataLoaderMock(dataset, batch_size=tsne_instance.batch_size)
        for _ in range(validation_dataloaders_count)
    ]
    trainer = L.Trainer(fast_dev_run=True, accelerator="cpu")

    input_P = torch.ones((num_samples, tsne_instance.batch_size))
    input_val_P = [
        torch.ones((num_samples, tsne_instance.batch_size))
        for _ in range(validation_dataloaders_count)
    ]
    classifier_instance.P = input_P
    classifier_instance.val_P = input_val_P

    with patch.object(DimensionalityReduction, "on_validation_start"):
        trainer.fit(classifier_instance, test_data, test_val_data)


@pytest.mark.parametrize(
    "optimizer, expected_instance",
    [
        ("adam", torch.optim.Adam),
        ("sgd", torch.optim.SGD),
        ("rmsprop", torch.optim.RMSprop),
    ],
)
def test_configure_optimizers(
    default_classifier_instance,
    optimizer: str,
    expected_instance: torch.optim.Optimizer,
):
    classifier_instance, _ = default_classifier_instance
    classifier_instance.optimizer = optimizer

    returned = classifier_instance.configure_optimizers()
    assert isinstance(returned, expected_instance)
    assert returned.param_groups[0]["lr"] == classifier_instance.lr


@pytest.mark.parametrize("optimizer", ["dummy_optimizer", "adom"])
def test_configure_optimizers_invalid(default_classifier_instance, optimizer: str):
    classifier_instance, _ = default_classifier_instance
    classifier_instance.optimizer = optimizer

    with pytest.raises(ValueError):
        classifier_instance.configure_optimizers()
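
# The `configure_optimizers` tests mirror the `_set_optimizer` ones exactly,
# which suggests the Lightning hook simply delegates with the stored settings.
# A one-line sketch (an assumption):
#
#     def configure_optimizers(self):
#         return self._set_optimizer(self.optimizer, {"lr": self.lr})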