Source code for NeuralTSNE.Utils.Writers.StatWriters.stat_writers

from typing import Any, List, Union

import torch
from torch.utils.data import DataLoader
from tqdm import tqdm


def save_means_and_vars(
    data: torch.Tensor,
    filtered_data: Union[torch.Tensor, None] = None,
    filename: str = "means_and_vars.txt",
) -> None:
    """
    Calculate and save the per-column means and variances of a 2D `torch.Tensor` to a file.

    If `filtered_data` is provided, the means and variances of its columns are
    appended to the same file as a second table.

    Parameters
    ----------
    `data` : `torch.Tensor`
        The input 2D tensor for which column means and variances are calculated.
    `filtered_data` : `torch.Tensor`, optional
        A filtered version of the input data. Defaults to `None`.
    `filename` : `str`, optional
        Path of the output file; an existing file is overwritten.
        Defaults to ``"means_and_vars.txt"`` (relative to the current working directory).

    Note
    -----
    - Statistics are reduced over dimension 0, so every output row describes one column.
    - Variances use the `torch.Tensor.var` default (unbiased, N - 1 denominator);
      a tensor with a single row therefore yields ``nan`` variances.
    - The output is tab-separated: a header line (``column``, ``mean``, ``var``)
      followed by one line per column; when `filtered_data` is given, a blank line
      and a second table headed ``filtered_column``, ``filtered_mean``,
      ``filtered_var`` follow.
    """

    def _stat_rows(tensor: torch.Tensor) -> List[str]:
        # One "index<TAB>mean<TAB>var" line per column. ``tolist`` moves the
        # statistics to Python numbers in a single transfer instead of one
        # ``.item()`` synchronisation per element; the text is the same.
        means = tensor.mean(dim=0).tolist()
        variances = tensor.var(dim=0).tolist()
        return [
            f"{index}\t{mean}\t{var}\n"
            for index, (mean, var) in enumerate(zip(means, variances))
        ]

    # Build every line before opening the file so that a failure while
    # computing statistics does not truncate an existing output file.
    lines = ["column\tmean\tvar\n", *_stat_rows(data)]
    if filtered_data is not None:
        lines.append("\nfiltered_column\tfiltered_mean\tfiltered_var\n")
        lines.extend(_stat_rows(filtered_data))

    with open(filename, "w", encoding="utf-8") as f:
        f.writelines(lines)


def save_results(
    args: dict, test: DataLoader, Y: Union[List[Any], List[List[Any]]]
) -> None:
    """
    Save results to a tab-separated text file.

    Parameters
    ----------
    `args` : `dict`
        Dictionary of arguments. Must contain the output file path under ``"o"``
        and a ``"step"`` value, which is written as the first line of the file.
    `test` : `DataLoader`
        DataLoader for the test dataset. It is only compared against `None`:
        when it is `None`, nothing is written.
    `Y` : `List[Any] | List[List[Any]]`
        Batches of results. Each batch is iterated to obtain entries, and each
        entry is iterated to obtain the values written on one output line.

    Note
    ----
    The file at ``args["o"]`` is overwritten. Values that expose an ``item``
    method (e.g. scalar tensors) are converted with ``.item()`` before being
    written; every value is then rendered with ``str`` and joined by tabs.
    """
    if test is None:
        return

    with open(args["o"], "w", encoding="utf-8") as f:
        f.write(f"{args['step']}\n")
        for batch in tqdm(Y, unit="batches", total=len(Y), desc="Saving results"):
            for entry in batch:
                # ``.item()`` turns a 0-d tensor into a plain Python number so that
                # ``str`` yields e.g. "1.5" instead of "tensor(1.5000)".
                values = (x.item() if hasattr(x, "item") else x for x in entry)
                f.write("\t".join(map(str, values)) + "\n")