Coverage for NeuralTSNE/NeuralTSNE/TSNE/Helpers/helpers.py: 82%
41 statements
« prev ^ index » next — coverage.py v7.8.0, created at 2025-05-18 16:32 +0000
1from typing import Tuple
2import torch
def Hbeta(D: torch.Tensor, beta: float) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Calculates entropy and probability distribution based on a distance matrix.

    Parameters
    ----------
    `D` : `torch.Tensor`
        Distance matrix (or a single row of distances).
    `beta` : `float`
        Precision parameter; distances are multiplied by `beta` before
        exponentiation.

    Returns
    -------
    `Tuple[torch.Tensor, torch.Tensor]`
        Entropy and probability distribution.

    Note
    ----
    The unnormalized weights are ``exp(-beta * D)``. The returned distribution
    is those weights divided by their sum over all elements of `D`. The
    returned entropy is ``log(sum(exp(-beta * D))) + beta * sum(D * P)``.

    The exponent is shifted by its minimum before calling `torch.exp`. The
    shift cancels exactly in both outputs, but it keeps the largest weight at
    ``1``. Without it, the sum could underflow to ``0`` for large
    ``beta * D`` and the results would become ``-inf``/``NaN``.
    """
    scaled = D * beta
    # `min` is undefined on an empty tensor; a zero shift keeps the original
    # (NaN) result for empty input.
    shift = scaled.min() if scaled.numel() > 0 else scaled.new_zeros(())
    centered = scaled - shift
    P = torch.exp(-centered)
    sumP = torch.sum(P)
    # log(sum(exp(-z))) = log(sum(exp(-(z - m)))) - m, and the -m is absorbed
    # by using (z - m) instead of z in the expectation term.
    H = torch.log(sumP) + torch.sum(centered * P) / sumP
    P = P / sumP
    return H, P
def x2p_job(
    data: Tuple[int, torch.Tensor, torch.Tensor],
    tolerance: float,
    max_iterations: int = 50,
) -> Tuple[int, torch.Tensor, torch.Tensor, int]:
    """
    Performs a binary search to find an appropriate value of `beta` for a given point.

    Parameters
    ----------
    `data` : `Tuple[int, torch.Tensor, torch.Tensor]`
        Tuple containing the point index, its distances to the other points,
        and the target entropy (log of the perplexity).
    `tolerance` : `float`
        Tolerance level for convergence.
    `max_iterations` : `int`, optional
        Maximum number of iterations for the binary search. Defaults to `50`.

    Returns
    -------
    `Tuple[int, torch.Tensor, torch.Tensor, int]`
        Index, probability distribution, entropy difference, and number of iterations.

    Note
    ----
    Starting from ``beta = 1``, `beta` is doubled or halved until the target
    entropy is bracketed. After that it is bisected between the bracket
    bounds. The search stops when ``|H - logU| <= tolerance`` or after
    `max_iterations` updates.
    """
    i, Di, logU = data
    beta = 1.0
    # Bracket for beta; an infinite bound means that side is not found yet.
    beta_min = -torch.inf
    beta_max = torch.inf

    H, thisP = Hbeta(Di, beta)
    Hdiff = H - logU

    it = 0
    while it < max_iterations and torch.abs(Hdiff) > tolerance:
        if Hdiff > 0:
            # Entropy too high: increase beta to sharpen the distribution.
            beta_min = beta
            # Plain float comparison; no need to allocate a tensor per step.
            if beta_max == torch.inf:
                beta *= 2
            else:
                beta = (beta + beta_max) / 2
        else:
            # Entropy too low: decrease beta to flatten the distribution.
            beta_max = beta
            if beta_min == -torch.inf:
                beta /= 2
            else:
                beta = (beta + beta_min) / 2

        H, thisP = Hbeta(Di, beta)
        Hdiff = H - logU
        it += 1
    return i, thisP, Hdiff, it
def x2p(
    X: torch.Tensor,
    perplexity: int,
    tolerance: float,
    max_iterations: int = 50,
) -> torch.Tensor:
    """
    Compute conditional probabilities.

    Parameters
    ----------
    `X` : `torch.Tensor`
        Input data tensor; rows are points.
    `perplexity` : `int`
        Perplexity parameter for t-SNE.
    `tolerance` : `float`
        Tolerance level for convergence.
    `max_iterations` : `int`, optional
        Maximum number of binary-search iterations per point, forwarded to
        `x2p_job`. Defaults to `50`.

    Returns
    -------
    `torch.Tensor`
        Conditional probability matrix of shape ``(n, n)``. The diagonal is
        zero and row ``i`` holds the distribution computed for point ``i``.
    """
    n = X.shape[0]
    P = torch.zeros(n, n, device=X.device)
    # With no points, the reshape below cannot infer its -1 dimension.
    if n == 0:
        return P

    logU = torch.log(torch.tensor([perplexity], device=X.device))

    # Squared Euclidean distances: ||x_i||^2 + ||x_j||^2 - 2 <x_i, x_j>.
    sum_X = torch.sum(torch.square(X), dim=1)
    D = torch.add(torch.add(-2 * torch.mm(X, X.mT), sum_X).T, sum_X)

    # Off-diagonal mask, created on X's device so indexing D and P does not
    # mix devices. Each row keeps its n - 1 distances to the other points.
    idx = ~torch.eye(n, dtype=torch.bool, device=X.device)
    D = D[idx].reshape((n, -1))

    for i in range(n):
        P[i, idx[i]] = x2p_job(
            (i, D[i], logU), tolerance, max_iterations=max_iterations
        )[1]
    return P