Coverage for NeuralTSNE/NeuralTSNE/Utils/Writers/StatWriters/stat_writers.py: 100%

27 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-05-18 16:32 +0000

1from typing import Any, List, Union 

2 

3import torch 

4from torch.utils.data import DataLoader 

5from tqdm import tqdm 

6 

7 

def _format_column_stats(header: str, values: torch.Tensor) -> List[str]:
    """Return `header` followed by one `index<TAB>mean<TAB>var` line per column of `values`."""
    means = values.mean(axis=0)
    variances = values.var(axis=0)
    lines = [header]
    # Iterating the reduced tensors yields 0-dim tensors, which format as plain numbers.
    lines.extend(
        f"{column}\t{mean}\t{variance}\n"
        for column, (mean, variance) in enumerate(zip(means, variances))
    )
    return lines


def save_means_and_vars(
    data: torch.Tensor,
    filtered_data: Union[torch.Tensor, None] = None,
    output_path: str = "means_and_vars.txt",
) -> None:
    """
    Calculate and save the means and variances of columns in a 2D `torch.Tensor` to a file.

    If `filtered_data` is provided, it calculates and saves means and variances for both original and filtered columns.

    Parameters
    ----------
    `data` : `torch.Tensor`
        The input 2D tensor for which means and variances are calculated.
    `filtered_data` : `torch.Tensor`, optional
        A filtered version of the input data. Defaults to `None`.
    `output_path` : `str`, optional
        Path of the text file to write. Defaults to `"means_and_vars.txt"`.

    Note
    -----
    - Means and variances are computed along axis 0 with `Tensor.mean` and `Tensor.var`.
    - The file is tab-separated: a `column\\tmean\\tvar` header followed by one line per column.
    - If `filtered_data` is provided, a blank line, a second header and the filtered statistics follow.
    - The file is opened in `"w"` mode, so any existing content is overwritten.
    """
    # Compute every statistic before opening the file, so a failure here
    # does not leave a truncated or half-written file behind.
    lines = _format_column_stats("column\tmean\tvar\n", data)
    if filtered_data is not None:
        lines += _format_column_stats(
            "\nfiltered_column\tfiltered_mean\tfiltered_var\n", filtered_data
        )

    with open(output_path, "w") as f:
        f.writelines(lines)

41 

42 

def _format_result_entry(entry) -> str:
    """Return `entry` as one tab-separated, newline-terminated line."""
    # Elements exposing `.item()` (e.g. scalar tensors) are unwrapped to plain values first.
    values = (x.item() if hasattr(x, "item") else x for x in entry)
    return "\t".join(str(value) for value in values) + "\n"


def save_results(
    args: dict, test: DataLoader, Y: Union[List[Any], List[List[Any]]]
) -> None:
    """
    Save results to a file.

    Parameters
    ----------
    `args` : `dict`
        Dictionary containing arguments, including the output file path (`o`) and step size (`step`).
    `test` : `DataLoader`
        DataLoader for the test dataset. Only checked against `None`; its contents are not read here.
    `Y` : `List[Any] | List[List[Any]]`
        List of results to be saved, iterated batch by batch.

    Note
    ----
    - Nothing is written (and no file is created) when `test` is `None`.
    - Otherwise the file at `args["o"]` is overwritten: the first line is `args["step"]`,
      followed by one tab-separated line per entry of every batch in `Y`.
    """
    if test is None:
        return

    with open(args["o"], "w") as f:
        f.write(f"{args['step']}\n")
        for batch in tqdm(Y, unit="batches", total=len(Y), desc="Saving results"):
            f.writelines(_format_result_entry(entry) for entry in batch)