Coverage for NeuralTSNE/NeuralTSNE/TSNE/CostFunctions/cost_functions.py: 100%
17 statements
« prev ^ index » next coverage.py v7.8.0, created at 2025-05-18 16:32 +0000
1from typing import Any
2import torch
class CostFunctions:
    """
    Namespace of cost functions for t-SNE.

    Instantiating the class with a function name returns the matching
    static method, e.g. ``CostFunctions("kl_divergence")``.
    """

    def __new__(cls, name: str):
        """
        Returns the specified cost function by name.

        Parameters
        ----------
        `name` : `str`
            The name of the cost function to retrieve.

        Returns
        -------
        `callable`
            The specified cost function.

        Raises
        ------
        `AttributeError`
            If no cost function with the given name exists.
        """
        # Look up on ``cls`` (not the hard-coded class name) so that
        # subclasses can register additional cost functions.
        return getattr(cls, name)

    @staticmethod
    def kl_divergence(
        Y: torch.Tensor, P: torch.Tensor, params: dict[str, Any]
    ) -> torch.Tensor:
        """
        Calculates the Kullback-Leibler divergence.

        Parameters
        ----------
        `Y` : `torch.Tensor`
            Embedding tensor of shape ``(n, d)``.
        `P` : `torch.Tensor`
            Conditional probability matrix of shape ``(n, n)``.
        `params` : `dict[str, Any]`
            Optional settings. ``"device"`` selects the device for auxiliary
            tensors (defaults to ``Y.device``). The legacy ``"batch_size"``
            key is accepted but no longer required: the batch size is taken
            from ``Y`` itself, so partial (smaller) final batches work too.

        Returns
        -------
        `torch.Tensor`
            Scalar Kullback-Leibler divergence ``KL(P || Q)``.

        Note
        ----
        Calculates the Kullback-Leibler divergence between the true conditional
        probability matrix P and the conditional probability matrix Q based on
        the current embedding Y.
        """
        device = params.get("device", Y.device)
        # Derive the batch size from the data itself: the diagonal mask below
        # must match Q's shape (n, n), so a hard-coded "batch_size" can only
        # ever agree with Y.shape[0] anyway.
        n = Y.shape[0]
        # Numerical floor to avoid log(0) and division by zero.
        eps = torch.tensor([1e-15], device=device)
        # Pairwise squared Euclidean distances: D[i, j] = ||Y_i - Y_j||^2.
        sum_Y = torch.sum(torch.square(Y), dim=1)
        D = sum_Y + torch.reshape(sum_Y, [-1, 1]) - 2 * torch.matmul(Y, Y.mT)
        # Student-t kernel with one degree of freedom: q_ij ∝ (1 + d_ij)^-1.
        Q = torch.pow(1 + D / 1.0, -(1.0 + 1) / 2)
        # Zero the diagonal (a point is not its own neighbour), then normalize.
        Q *= 1 - torch.eye(n, device=device)
        Q /= torch.sum(Q)
        Q = torch.maximum(Q, eps)
        # KL(P || Q) = sum_ij p_ij * log(p_ij / q_ij).
        C = torch.log((P + eps) / (Q + eps))
        return torch.sum(P * C)