import math
import numpy as np
from torch.utils.data.sampler import Sampler


class DistributedTestSampler(Sampler):
    def __init__(self, dataset, world_size, rank, validation=False):
        num_total = len(dataset)
        part = math.ceil(num_total / world_size)
        if rank == world_size - 1:
            # the last rank takes whatever remains after the other ranks get `part` samples each
            self.num_samples = num_total - (part * (world_size - 1))
            self.indices = range(part * (world_size - 1), num_total)
            if validation:
                # pad with indices from the start of the dataset so this rank
                # yields as many indices as the other ranks
                self.indices = list(self.indices) + list(range(part - self.num_samples))
        else:
            self.num_samples = part
            self.indices = range(part * rank, part * (rank + 1))

    def __len__(self):
        return self.num_samples

    def __iter__(self):
        return iter(self.indices)


class DistributedGivenIterationSampler(Sampler):
    def __init__(self, dataset, total_iter, batch_size, world_size=None, rank=None, last_iter=-1):
        self.dataset = dataset
        self.total_iter = total_iter
        self.batch_size = batch_size
        self.world_size = world_size
        self.rank = rank
        self.last_iter = last_iter

        self.total_size = self.total_iter * self.batch_size
        self.indices = self.gen_new_list()
        self.call = 0

    def __iter__(self):
        if self.call == 0:
            self.call = 1
            # resume from the first index after the already-completed iterations
            return iter(self.indices[(self.last_iter + 1) * self.batch_size:])
        else:
            raise RuntimeError("this sampler is not designed to be called more than once!!")

    def gen_new_list(self):
        # each process shuffles the full index list with the same seed,
        # then picks its own contiguous piece according to its rank
        np.random.seed(0)
        all_size = self.total_size * self.world_size
        indices = np.arange(len(self.dataset))
        indices = indices[:all_size]
        # repeat the dataset indices until there are enough for all ranks and iterations
        num_repeat = (all_size - 1) // indices.shape[0] + 1
        indices = np.tile(indices, num_repeat)
        indices = indices[:all_size]

        np.random.shuffle(indices)
        beg = self.total_size * self.rank
        indices = indices[beg:beg + self.total_size]

        assert len(indices) == self.total_size
        return indices

    def __len__(self):
        # note: last_iter is not taken into consideration here, since __len__
        # should only be used for display purposes; the correct remaining size
        # is handled by the dataloader
        # return self.total_size - (self.last_iter+1)*self.batch_size
        return self.total_size
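For reference, a minimal usage sketch (not part of the original listing): it assumes an already-initialized torch.distributed process group and a hypothetical map-style dataset named `my_dataset`; the batch size and iteration count are placeholders.

import torch.distributed as dist
from torch.utils.data import DataLoader

# `my_dataset` is a placeholder for any map-style torch Dataset;
# rank/world_size come from an initialized process group
rank = dist.get_rank()
world_size = dist.get_world_size()

# evaluation: each rank iterates over its own contiguous shard of the dataset
test_sampler = DistributedTestSampler(my_dataset, world_size, rank)
test_loader = DataLoader(my_dataset, batch_size=32, sampler=test_sampler)

# training for a fixed number of iterations; set last_iter to the iteration
# reached before a restart to resume from the following batch
train_sampler = DistributedGivenIterationSampler(
    my_dataset, total_iter=10000, batch_size=32,
    world_size=world_size, rank=rank, last_iter=-1)
train_loader = DataLoader(my_dataset, batch_size=32, sampler=train_sampler)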