Coverage for NeuralTSNE/NeuralTSNE/TSNE/Helpers/helpers.py: 82%

41 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-05-18 16:32 +0000

1from typing import Tuple 

2import torch 

3 

4 

def Hbeta(D: torch.Tensor, beta: float) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Calculates entropy and probability distribution based on a distance matrix.

    Parameters
    ----------
    `D` : `torch.Tensor`
        Distance matrix (or a single row of it). All elements are treated as one
        distribution, regardless of the tensor's shape.
    `beta` : `float`
        Scale applied to the distances before exponentiation; larger values give a
        sharper (lower-entropy) distribution.

    Returns
    -------
    `Tuple[torch.Tensor, torch.Tensor]`
        Entropy (0-d tensor, natural log) and the normalized probability
        distribution, which has the same shape as `D`.

    Note
    ----
    The distribution is ``P = exp(-beta * D) / sum(exp(-beta * D))`` and the
    entropy is ``H = log(sum(exp(-beta * D))) + beta * sum(D * P)``.
    The normalizer is computed in log-space with `torch.logsumexp`, so large
    values of ``beta * D`` no longer underflow to an all-zero kernel. The naive
    form would then yield ``H = -inf`` and ``P = nan``.
    """
    logits = -beta * D
    # log(sum(exp(logits))) computed stably over every element of D.
    log_sumP = torch.logsumexp(logits.reshape(-1), dim=0)
    # Normalized distribution; exp(logits - log_sumP) sums to 1 without overflow/underflow.
    P = torch.exp(logits - log_sumP)
    H = log_sumP + beta * torch.sum(D * P)
    return H, P

31 

32 

def x2p_job(
    data: Tuple[int, torch.Tensor, torch.Tensor],
    tolerance: float,
    max_iterations: int = 50,
) -> Tuple[int, torch.Tensor, torch.Tensor, int]:
    """
    Performs a binary search to find an appropriate value of `beta` for a given point.

    Parameters
    ----------
    `data` : `Tuple[int, torch.Tensor, torch.Tensor]`
        Tuple containing the point index, its distances to the other points, and
        the target entropy (log of the perplexity).
    `tolerance` : `float`
        Tolerance level for convergence: the search stops once
        ``|H - target| <= tolerance``.
    `max_iterations` : `int`, optional
        Maximum number of iterations for the binary search. Defaults to `50`.

    Returns
    -------
    `Tuple[int, torch.Tensor, torch.Tensor, int]`
        Index (passed through unchanged), probability distribution for the final
        `beta`, entropy difference ``H - target`` (a tensor broadcast to the
        target's shape), and number of iterations performed.

    Note
    ----
    `beta` starts at ``1.0`` with an unbounded bracket. While a side of the
    bracket is still unbounded, `beta` is doubled or halved. Once both bounds
    are known, it is bisected.
    """
    i, Di, logU = data
    beta = 1.0
    # Infinite values are sentinels meaning "no bound found on this side yet".
    beta_min = -torch.inf
    beta_max = torch.inf

    H, thisP = Hbeta(Di, beta)
    Hdiff = H - logU

    it = 0
    while it < max_iterations and torch.abs(Hdiff) > tolerance:
        if Hdiff > 0:
            # Entropy above target: raise beta to sharpen the distribution.
            beta_min = beta
            # Plain float comparison against the sentinel; avoids allocating a
            # (float32) tensor on every iteration just to test for infinity.
            beta = beta * 2 if beta_max == torch.inf else (beta + beta_max) / 2
        else:
            # Entropy below target: lower beta to flatten the distribution.
            beta_max = beta
            beta = beta / 2 if beta_min == -torch.inf else (beta + beta_min) / 2

        H, thisP = Hbeta(Di, beta)
        Hdiff = H - logU
        it += 1
    return i, thisP, Hdiff, it

87 

88 

def x2p(
    X: torch.Tensor,
    perplexity: int,
    tolerance: float,
) -> torch.Tensor:
    """
    Compute conditional probabilities.

    Parameters
    ----------
    `X` : `torch.Tensor`
        Input data tensor with one point per row (indexed as ``X.shape[0]`` points).
    `perplexity` : `int`
        Perplexity parameter for t-SNE; its log is the target entropy of each row.
    `tolerance` : `float`
        Tolerance level for convergence of the per-point binary search.

    Returns
    -------
    `torch.Tensor`
        ``(n, n)`` conditional probability matrix on ``X.device``. Row ``i`` holds
        the distribution over all points ``j != i``; the diagonal is zero.
    """
    n = X.shape[0]
    # Build from a float so torch.log never has to promote an integer tensor.
    logU = torch.log(torch.tensor([float(perplexity)], device=X.device))

    # Pairwise squared Euclidean distances: ||x_i||^2 - 2 <x_i, x_j> + ||x_j||^2.
    sum_X = torch.sum(torch.square(X), dim=1)
    D = torch.add(torch.add(-2 * torch.mm(X, X.mT), sum_X).T, sum_X)
    # Floating-point cancellation can leave tiny negative values; squared distances are >= 0.
    D = torch.clamp(D, min=0)

    # Off-diagonal mask, created on X's device so indexing does not mix devices.
    idx = ~torch.eye(n, dtype=torch.bool, device=X.device)
    # Drop self-distances; max() keeps the reshape valid when n == 0.
    D = D[idx].reshape((n, max(n - 1, 0)))

    P = torch.zeros(n, n, device=X.device)

    for i in range(n):
        P[i, idx[i]] = x2p_job((i, D[i], logU), tolerance)[1]
    return P