Coverage for NeuralTSNE/NeuralTSNE/Utils/Writers/LabelWriters/label_writers.py: 100%

18 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-05-18 16:32 +0000

import os

from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

3 

4 

def save_torch_labels(output: str, test: Dataset) -> None:
    """
    Save labels from a `torch.Dataset` to a text file.

    The function takes the second element (``row[1]``) of every item in the
    provided `test` dataset and writes it to a text file, one label per line.
    The output file name is derived from the provided `output` parameter.

    Parameters
    ----------
    `output` : `str`
        The output file path from which the labels file name is derived.
    `test` : `Dataset`
        The `torch.Dataset` containing labels to be saved. It must support
        ``len()``, and each item must be indexable with ``[1]`` to obtain its
        label (presumably ``(features, label)`` pairs -- confirm against callers).

    Note
    ----
    - Each label is written using its ``str`` representation followed by a newline.
      NOTE(review): if labels are tensors, this writes e.g. ``tensor(3)``; confirm
      that this is intended.
    - The output file is named by appending `"_labels.txt"` to the `output` parameter,
      removing the file extension if present. Dots in directory names (e.g. ``./out``
      or ``runs.v2/out``) are not treated as an extension separator.
    """
    # `splitext` strips only a real extension of the final path component,
    # unlike `rsplit(".")`, which would cut at a dot inside a directory name.
    root, _ = os.path.splitext(output)
    with open(root + "_labels.txt", "w") as f:
        for row in tqdm(
            test, unit="samples", total=len(test), desc="Saving labels"
        ):
            f.write(f"{row[1]}\n")

32 

33 

def save_labels_data(
    args: dict,
    test: DataLoader,
) -> None:
    """
    Save labels data to a new file.

    Every element of every batch yielded by `test` is converted with
    ``.tolist()``. Each row of the resulting list is then written as one line
    of tab-separated values.

    Parameters
    ----------
    `args` : `dict`
        Dictionary containing arguments, including the output file path (`o`).
    `test` : `DataLoader`
        DataLoader for the test dataset. If `None`, nothing is written.
        Each batch element must provide ``.tolist()`` returning a list of
        iterable rows (presumably 2-D tensors -- confirm against callers).

    Note
    ----
    - The file is named by appending `"_labels.txt"` to `args["o"]` after removing
      its file extension, if present. Dots in directory names are not treated as
      an extension separator.
    - Every value is followed by a tab character (including the last value of a
      row), and every row is terminated by a newline.
    """
    if test is None:
        return

    # `splitext` strips only a real extension of the final path component,
    # unlike `rsplit(".")`, which would cut at a dot inside a directory name.
    root, _ = os.path.splitext(args["o"])
    with open(root + "_labels.txt", "w") as f:
        for batch in tqdm(
            test,
            unit="batches",
            total=len(test),
            desc="Saving new labels",
        ):
            for tensor in batch:
                for sample in tensor.tolist():
                    # Keep the established format: a tab after *every* value,
                    # so downstream readers see the same layout as before.
                    f.write("".join(str(col) + "\t" for col in sample) + "\n")