sampler.py

import math

import numpy as np
from torch.utils.data.sampler import Sampler


class DistributedTestSampler(Sampler):
    """Splits a dataset into contiguous, (almost) equal-sized pieces, one per rank."""

    def __init__(self, dataset, world_size, rank, validation=False):
        num_total = len(dataset)
        part = math.ceil(num_total / world_size)
        if rank == world_size - 1:
            # the last rank takes whatever remains after the other ranks
            self.num_samples = num_total - part * (world_size - 1)
            self.indices = range(part * (world_size - 1), num_total)
            if validation:
                # pad with the first (part - num_samples) indices of the dataset so
                # every rank yields the same number of samples; note that num_samples
                # is left unchanged, so __len__ still reports the unpadded count
                self.indices = list(self.indices) + list(range(part - self.num_samples))
        else:
            self.num_samples = part
            self.indices = range(part * rank, part * (rank + 1))

    def __len__(self):
        return self.num_samples

    def __iter__(self):
        return iter(self.indices)
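
# Illustrative example (numbers are made up, not from the original file):
# with world_size = 4 and a dataset of 10 samples, part = ceil(10 / 4) = 3,
# so ranks 0-2 get indices [0, 3), [3, 6), [6, 9) and rank 3 gets [9, 10).
# With validation=True, rank 3 is padded with indices 0 and 1 from the start
# of the dataset so that it also yields 3 samples.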


class DistributedGivenIterationSampler(Sampler):
    """Yields exactly total_iter * batch_size indices for this rank, resumable via last_iter."""

    def __init__(self, dataset, total_iter, batch_size, world_size=None, rank=None, last_iter=-1):
        self.dataset = dataset
        self.total_iter = total_iter
        self.batch_size = batch_size
        self.world_size = world_size
        self.rank = rank
        self.last_iter = last_iter
        self.total_size = self.total_iter * self.batch_size
        self.indices = self.gen_new_list()
        self.call = 0

    def __iter__(self):
        if self.call == 0:
            self.call = 1
            # skip the indices that were already consumed before resuming
            return iter(self.indices[(self.last_iter + 1) * self.batch_size:])
        else:
            raise RuntimeError("this sampler is not designed to be called more than once!!")
    def gen_new_list(self):
        # every process shuffles the full index list with the same seed,
        # then keeps the slice that belongs to its own rank
        np.random.seed(0)
        all_size = self.total_size * self.world_size
        indices = np.arange(len(self.dataset))
        indices = indices[:all_size]
        # repeat the dataset indices until they cover all_size entries
        num_repeat = (all_size - 1) // indices.shape[0] + 1
        indices = np.tile(indices, num_repeat)
        indices = indices[:all_size]
        np.random.shuffle(indices)
        beg = self.total_size * self.rank
        indices = indices[beg:beg + self.total_size]
        assert len(indices) == self.total_size
        return indices
    def __len__(self):
        # note: the iterations already consumed (last_iter) are not subtracted here,
        # since __len__ should only be used for display; the correct remaining size
        # is handled by the dataloader
        # return self.total_size - (self.last_iter + 1) * self.batch_size
        return self.total_size
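

# Usage sketch (not part of the original file): how these samplers would typically
# be plugged into torch DataLoaders. The dataset, world_size/rank, iteration count
# and batch size below are made-up placeholder values; in a real distributed run,
# world_size and rank would come from the distributed launcher / torch.distributed.
if __name__ == "__main__":
    import torch
    from torch.utils.data import DataLoader, TensorDataset

    dataset = TensorDataset(torch.arange(100).float())
    world_size, rank = 4, 0

    # Evaluation: each rank iterates once over its own contiguous slice.
    test_sampler = DistributedTestSampler(dataset, world_size, rank)
    test_loader = DataLoader(dataset, batch_size=8, sampler=test_sampler)

    # Training: a fixed total number of iterations, resumable via last_iter.
    train_sampler = DistributedGivenIterationSampler(
        dataset, total_iter=50, batch_size=8,
        world_size=world_size, rank=rank, last_iter=-1)
    train_loader = DataLoader(dataset, batch_size=8, sampler=train_sampler)

    print(len(test_sampler), len(train_sampler))  # 25 400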