Coverage for NeuralTSNE/NeuralTSNE/TSNE/tests/test_neural_network.py: 96%

49 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-05-18 16:32 +0000

1import pytest 

2import torch 

3from collections import OrderedDict 

4from typing import List 

5 

6from NeuralTSNE.TSNE.tests.fixtures.neural_network_fixtures import ( 

7 neural_network_params, 

8 neural_network, 

9) 

10 

11 

@pytest.mark.parametrize(
    "neural_network_params",
    [
        # Tuples are consumed by the ``neural_network_params`` fixture (defined
        # elsewhere); presumably (initial_features, n_components, multipliers)
        # -- confirm against the fixture.
        (10, 5, [2, 3, 1, 0.5]),
        (5, 3, [2, 6, 1]),
        (3, 2, [2, 1]),
    ],
    indirect=True,
)
def test_neural_network_forward(neural_network_params, neural_network):
    """A forward pass maps a ``(batch, initial_features)`` tensor to ``(batch, n_components)``."""
    batch_size = 10
    features_in = neural_network_params["initial_features"]
    features_out = neural_network_params["n_components"]

    result = neural_network(torch.randn(batch_size, features_in))

    assert result.shape == (batch_size, features_out)

26 

27 

@pytest.mark.parametrize(
    "neural_network_params, activation_functions",
    [
        (
            (10, 2, [2, 3, 4, 2, 0.3, 1.4]),
            [(i, torch.nn.ReLU) for i in range(1, 13, 2)],
        ),
        ((20, 3, [2, 6, 1]), [(i, torch.nn.ReLU) for i in range(1, 7, 2)]),
        ((15, 4, [2, 1]), [(i, torch.nn.ReLU) for i in range(1, 5, 2)]),
    ],
    indirect=["neural_network_params"],
)
def test_neural_network_layer_shapes(
    neural_network_params, neural_network, activation_functions: List[tuple]
):
    """Check layer widths and activation placement in ``sequential_stack``.

    ``activation_functions`` holds ``(layer_index, activation_class)`` pairs:
    the layer at each listed index must be an instance of that class, and
    every other layer must expose ``out_features``. The collected widths must
    be ``initial_features``, then ``int(multiplier * initial_features)`` for
    each multiplier, then ``n_components`` for the output layer.
    """
    initial_features = neural_network_params["initial_features"]
    stack = neural_network.sequential_stack

    # Build a lookup instead of popping from ``activation_functions``: pytest
    # reuses the same parametrized list object if this case runs again in the
    # same session (e.g. reruns), so mutating it would break later runs.
    expected_activations = dict(activation_functions)

    layer_shapes = [initial_features]
    misplaced = []
    for index, layer in enumerate(stack):
        if index not in expected_activations:
            layer_shapes.append(layer.out_features)
        elif not isinstance(layer, expected_activations[index]):
            misplaced.append((index, type(layer).__name__))

    expected_shapes = [
        initial_features,
        *(
            int(multiplier * initial_features)
            for multiplier in neural_network_params["multipliers"]
        ),
        # The output layer must project to ``n_components`` (the forward test
        # asserts the same thing on the output tensor).
        neural_network_params["n_components"],
    ]

    assert layer_shapes == expected_shapes
    # Every expected activation index must actually exist in the stack.
    assert all(index < len(stack) for index in expected_activations)
    assert not misplaced, f"unexpected layer types at: {misplaced}"

69 

70 

@pytest.mark.parametrize(
    "neural_network_params",
    [
        (
            10,
            5,
            [2, 3, 1, 0.5],
            OrderedDict(
                {
                    "0": torch.nn.Linear(10, 20),
                    "ReLu0": torch.nn.ReLU(),
                    "1": torch.nn.Linear(20, 30),
                    "GeLu1": torch.nn.GELU(),
                    "2": torch.nn.Linear(30, 5),
                    "ELu2": torch.nn.ELU(),
                    "3": torch.nn.Linear(5, 5),
                }
            ),
        )
    ],
    indirect=True,
)
def test_neural_network_pre_filled_layers(neural_network_params, neural_network):
    """The network's ``sequential_stack`` must be exactly the pre-filled layers, in order.

    The comparison is by identity: each module in the stack must be the very
    object supplied through ``pre_filled_layers``.
    """
    pre_filled_layers = neural_network_params["pre_filled_layers"]
    neural_network_pre_filled = torch.nn.Sequential(pre_filled_layers)

    # ``zip`` stops at the shorter sequence, so compare lengths first;
    # otherwise a missing or empty stack would pass vacuously.
    assert len(neural_network.sequential_stack) == len(neural_network_pre_filled)

    for layer, pre_filled_layer in zip(
        neural_network.sequential_stack, neural_network_pre_filled
    ):
        # ``torch.nn.Module`` keeps default (identity) equality, so ``is``
        # checks the same thing ``==`` would, but states the intent explicitly.
        assert layer is pre_filled_layer

103 

104 

@pytest.mark.parametrize(
    "neural_network_params", [(10, 5, [2, 3, 1, 0.5])], indirect=True
)
def test_neural_network_gradients(neural_network_params, neural_network):
    """One SGD step on a summed MSE loss must update the first layer's weights.

    Also checks that backpropagation populated a non-zero gradient for that
    weight before the optimizer step.
    """
    input_data = torch.randn(10, neural_network_params["initial_features"])
    target = torch.rand(10, neural_network_params["n_components"])

    loss_fn = torch.nn.MSELoss(reduction="sum")
    optimizer = torch.optim.SGD(neural_network.parameters(), lr=1e-1)

    neural_network.train()

    first_weight = neural_network.sequential_stack[0].weight
    # Snapshot the weights: ``first_weight`` is the live parameter, which
    # ``optimizer.step()`` updates in place.
    weights_before = first_weight.detach().clone()

    optimizer.zero_grad()
    loss = loss_fn(neural_network(input_data), target)
    loss.backward()

    assert first_weight.grad is not None
    assert torch.any(first_weight.grad != 0)

    optimizer.step()

    assert not torch.allclose(weights_before, first_weight.detach(), atol=1e-8)