Coverage for NeuralTSNE/NeuralTSNE/TSNE/CostFunctions/cost_functions.py: 100%

17 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-05-18 16:32 +0000

1from typing import Any 

2import torch 

3 

4 

class CostFunctions:
    """
    Class containing cost functions for t-SNE.

    Calling ``CostFunctions(name)`` does not build an instance; it looks up
    and returns the cost function registered on the class under ``name``.
    """

    def __new__(cls, name):
        """
        Returns the specified cost function by name.

        Parameters
        ----------
        `name` : `str`
            The name of the cost function to retrieve.

        Returns
        -------
        `callable`
            The specified cost function.

        Raises
        ------
        `AttributeError`
            If the class has no attribute called `name`.
        """
        # Look up on `cls` (not a hard-coded class name) so subclasses can
        # expose additional cost functions through the same factory call.
        return getattr(cls, name)

    @staticmethod
    def kl_divergence(
        Y: torch.Tensor, P: torch.Tensor, params: dict[str, Any]
    ) -> torch.Tensor:
        """
        Calculates the Kullback-Leibler divergence.

        Parameters
        ----------
        `Y` : `torch.Tensor`
            Embedding tensor; rows are points (indexed as ``Y.shape[0]``).
        `P` : `torch.Tensor`
            Probability matrix of the input space. It must broadcast against
            the ``(n, n)`` pairwise matrix Q built from `Y`.
        `params` : `dict[str, Any]`
            Kept for interface compatibility. The device and the number of
            points are taken from `Y` itself, so no keys are required.

        Returns
        -------
        `torch.Tensor`
            Kullback-Leibler divergence (scalar tensor).

        Note
        ----
        Calculates the Kullback-Leibler divergence between the true probability matrix P
        and the probability matrix Q based on the current embedding Y.
        """
        n_points = Y.shape[0]
        # Small constant that guards against log(0) and division by zero.
        # It is created on Y's device so it always matches the tensors it meets.
        eps = torch.tensor([1e-15], device=Y.device)

        # Pairwise squared Euclidean distances: |y_i|^2 + |y_j|^2 - 2 <y_i, y_j>.
        sum_Y = torch.sum(torch.square(Y), dim=1)
        D = sum_Y + torch.reshape(sum_Y, [-1, 1]) - 2 * torch.matmul(Y, Y.mT)

        # Student-t kernel with `dof` degrees of freedom; with dof = 1 this is
        # (1 + d)^-1.
        dof = 1.0
        Q = torch.pow(1 + D / dof, -(dof + 1) / 2)

        # Zero the diagonal so a point is not its own neighbour. The mask is
        # sized from Y, so a batch shorter than the configured batch size
        # still lines up with Q.
        Q = Q * (1 - torch.eye(n_points, device=Y.device))
        Q = Q / torch.sum(Q)
        Q = torch.maximum(Q, eps)

        C = torch.log((P + eps) / (Q + eps))
        return torch.sum(P * C)