Coverage for NeuralTSNE/NeuralTSNE/TSNE/tests/common.py: 100%

38 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-05-18 16:32 +0000

1import io 

2from typing import Tuple 

3 

4import torch 

5from torch.utils.data import Dataset 

6 

7 

class PersistentStringIO(io.StringIO):
    """In-memory text buffer whose ``close()`` does not release the contents.

    ``close()`` only records that closing was requested and never calls
    ``io.StringIO.close()``. Consequently ``getvalue()``, ``write()`` and
    friends keep working afterwards, while ``closed`` reports ``True``. Code
    that closes its stream (for example via a ``with`` block) can therefore
    still be inspected afterwards.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Our own "closed" flag; the underlying StringIO is never really closed.
        self._closed = False

    # Report the recorded flag instead of io.StringIO's real closed state.
    closed = property(lambda self: self._closed)

    def close(self):
        # Intentionally do not delegate to io.StringIO.close(): that would
        # discard the buffer and make getvalue() raise.
        self._closed = True

19 

20 

21class MyDataset(Dataset): 

22 def __init__( 

23 self, 

24 num_samples: int, 

25 num_variables: int, 

26 item_range: Tuple[float, float] | Tuple[int, int] = None, 

27 generate_int: bool = False, 

28 ): 

29 self.num_samples = num_samples 

30 self.num_variables = num_variables 

31 self.item_range = item_range or (0, 1) 

32 self.generate_int = generate_int 

33 

34 def __len__(self): 

35 return self.num_samples 

36 

37 def __getitem__(self, index): 

38 if self.generate_int: 

39 sample = torch.randint(*self.item_range, size=(self.num_variables,)) 

40 else: 

41 sample = torch.FloatTensor(self.num_variables).uniform_(*self.item_range) 

42 return tuple([sample]) 

43 

44 

class DataLoaderMock:
    """Minimal sequential stand-in for ``torch.utils.data.DataLoader`` in tests.

    Iterates ``dataset`` in order (no shuffling) in batches of at most
    ``batch_size`` samples. It stacks each field of the per-sample tuples
    along a new leading dimension and records every yielded batch in
    ``self.batches``.

    Args:
        dataset: Map-style dataset supporting ``len()`` and integer indexing,
            whose items are tuples of tensors.
        batch_size: Maximum number of samples per batch. The final batch is
            smaller when ``len(dataset)`` is not a multiple of it.
    """

    def __init__(self, dataset: Dataset, batch_size: int):
        self.dataset = dataset
        self.batch_size = batch_size
        # Every batch yielded by __iter__, accumulated across iterations.
        self.batches = []

    def __iter__(self):
        num_samples = len(self.dataset)
        for start in range(0, num_samples, self.batch_size):
            # Clamp so the last batch never indexes past the end of the dataset.
            stop = min(start + self.batch_size, num_samples)
            # Fetch each sample exactly once. Datasets that generate data on
            # access would otherwise return a different draw for every field.
            samples = [self.dataset[index] for index in range(start, stop)]
            # zip(*samples) groups the k-th field of every sample together.
            batch = tuple(torch.stack(field, dim=0) for field in zip(*samples))
            self.batches.append(batch)
            yield batch

    def __len__(self):
        # NOTE(review): this is the number of samples, whereas
        # torch DataLoader.__len__ is the number of batches. Kept unchanged
        # because existing callers/tests may rely on it.
        return len(self.dataset)