Coverage for NeuralTSNE/NeuralTSNE/Utils/Writers/StatWriters/stat_writers.py: 100%
27 statements
Generated by coverage.py v7.8.0 on 2025-05-18 16:32 +0000.
1from typing import Any, List, Union
3import torch
4from torch.utils.data import DataLoader
5from tqdm import tqdm
def save_means_and_vars(
    data: torch.Tensor,
    filtered_data: Union[torch.Tensor, None] = None,
    file_path: str = "means_and_vars.txt",
) -> None:
    """
    Calculate and save the means and variances of columns in a 2D `torch.Tensor` to a file.

    If `filtered_data` is provided, it calculates and saves means and variances for both original and filtered columns.

    Parameters
    ----------
    `data` : `torch.Tensor`
        The input 2D tensor for which means and variances are calculated.
    `filtered_data` : `torch.Tensor`, optional
        A filtered version of the input data. Defaults to `None`.
    `file_path` : `str`, optional
        Path of the output file. Defaults to `"means_and_vars.txt"`, resolved
        relative to the current working directory (the previous hard-coded name).

    Note
    -----
    - Statistics are computed per column (reduction over dimension 0). Variances use
      the default of `torch.Tensor.var` (unbiased, N - 1 denominator).
    - The output file is overwritten. It holds a tab-separated table with the header
      `column`, `mean`, `var` and one row per column of `data`. When `filtered_data`
      is given, the file then contains a blank line and a second table with the header
      `filtered_column`, `filtered_mean`, `filtered_var`.
    - All statistics are computed before the file is opened. An error during
      computation therefore leaves no partially written file behind.
    """

    def _stat_rows(tensor: torch.Tensor) -> List[str]:
        # Format one "<index>\t<mean>\t<var>" line per column of `tensor`.
        # `.tolist()` converts to Python floats once. This formats identically to
        # f-string formatting of 0-d tensors, which defers to `.item()`.
        means = tensor.mean(dim=0).tolist()
        variances = tensor.var(dim=0).tolist()
        return [
            f"{index}\t{mean}\t{var}\n"
            for index, (mean, var) in enumerate(zip(means, variances))
        ]

    lines = ["column\tmean\tvar\n", *_stat_rows(data)]
    if filtered_data is not None:
        lines.append("\nfiltered_column\tfiltered_mean\tfiltered_var\n")
        lines.extend(_stat_rows(filtered_data))

    with open(file_path, "w", encoding="utf-8") as f:
        f.write("".join(lines))
def save_results(
    args: dict, test: DataLoader, Y: Union[List[Any], List[List[Any]]]
) -> None:
    """
    Save results to a file.

    Parameters
    ----------
    `args` : `dict`
        Dictionary containing arguments, including the output file path (`o`) and step size (`step`).
    `test` : `DataLoader`
        DataLoader for the test dataset. Nothing is written when it is `None`;
        otherwise it is not read.
    `Y` : `List[Any] | List[List[Any]]`
        List of results to be saved. Each element is treated as a batch. Each batch
        is iterated into entries, and each entry is iterated into values.

    Note
    ----
    This function saves the results to a file specified by the output file path in the arguments.
    The file is overwritten. Its first line is `args["step"]`. After that comes one
    tab-separated line per entry. Values exposing `.item()` (e.g. scalar tensors)
    are unwrapped before being converted with `str`.
    """
    if test is None:
        return

    with open(args["o"], "w") as f:
        f.write(f"{args['step']}\n")
        for batch in tqdm(Y, unit="batches", total=len(Y), desc="Saving results"):
            for entry in batch:
                # Use .item() when available (e.g. scalar tensors) so values print
                # as plain Python numbers rather than tensor reprs.
                values = [x.item() if hasattr(x, "item") else x for x in entry]
                f.write("\t".join(str(x) for x in values) + "\n")