Coverage for NeuralTSNE/NeuralTSNE/Utils/Writers/LabelWriters/label_writers.py: 100%
18 statements
« prev ^ index » next coverage.py v7.8.0, created at 2025-05-18 16:32 +0000
import os

from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
def save_torch_labels(output: str, test: Dataset) -> None:
    """
    Save labels from a `torch.Dataset` to a text file.

    The function extracts labels from the provided `test` dataset and writes them to a
    text file, one label per line. The output file is named based on the provided
    `output` parameter.

    Parameters
    ----------
    `output` : `str`
        The output file path from which the labels file path is derived.
    `test` : `Dataset`
        The `torch.Dataset` containing labels to be saved.

    Note
    ----
    - Each item of `test` is indexed with `[1]` and written using f-string formatting,
      so items are presumably `(data, label)` pairs (not checked here).
    - `test` must support `len()`, which is used as the progress-bar total.
    - The output file is named by removing the file extension from `output` (if present)
      and appending `"_labels.txt"`, e.g. `"results/run.txt"` -> `"results/run_labels.txt"`.
      Dots in directory names or leading dots of hidden files are not treated as extensions.
    - An existing file at that path is overwritten.
    """
    # os.path.splitext only strips a real extension of the final path component;
    # a plain rsplit(".") would also cut at dots in directory names (e.g. "./out").
    labels_path = os.path.splitext(output)[0] + "_labels.txt"
    with open(labels_path, "w") as f:
        for row in tqdm(test, unit="samples", total=len(test), desc="Saving labels"):
            f.write(f"{row[1]}\n")
def save_labels_data(
    args: dict,
    test: DataLoader,
) -> None:
    """
    Save labels data to a new tab-separated text file.

    Parameters
    ----------
    `args` : `dict`
        Dictionary containing arguments, including the output file path (`o`).
    `test` : `DataLoader`
        DataLoader for the test dataset. If `None`, nothing is written.

    Note
    ----
    - The labels file path is derived from `args["o"]` by removing its file extension
      (if present) and appending `"_labels.txt"`. Dots in directory names are not
      treated as extensions. An existing file at that path is overwritten.
    - Every element of each batch is converted with `.tolist()`, and each resulting row
      is written as one line. Every value is followed by a tab, so each line ends with a
      tab before the newline (kept for compatibility with previously written files).
    - Rows must be iterable after `.tolist()`, i.e. batch elements are presumably 2-D
      tensors (TODO confirm against callers; a 1-D tensor raises `TypeError`).
    """
    if test is None:
        return

    # splitext strips only a real extension of the final path component; rsplit(".")
    # would also cut at dots in directory names (e.g. "./out" -> "_labels.txt").
    labels_path = os.path.splitext(args["o"])[0] + "_labels.txt"
    with open(labels_path, "w") as f:
        for batch in tqdm(
            test,
            unit="batches",
            total=len(test),
            desc="Saving new labels",
        ):
            for tensor in batch:
                for row in tensor.tolist():
                    # Build the whole line once; the trailing tab per value is intentional.
                    f.write("".join(str(col) + "\t" for col in row) + "\n")