sampler.py 2.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
  3. import random
  4. from itertools import islice
  5. import numpy as np
  6. import torch
  7. class LengthBasedBatchSampler(torch.utils.data.BatchSampler):
  8. def __init__(self, data_source, batch_size: int, drop_last: bool, shuffle: bool=True) -> None:
  9. if isinstance(next(iter(data_source)), dict):
  10. first_key = next(iter(next(iter(data_source)).keys()))
  11. self.lengths = [len(d[first_key]) for d in data_source]
  12. else:
  13. self.lengths = [len(d) for d in data_source]
  14. self.batch_size = batch_size
  15. self.drop_last = drop_last
  16. self.shuffle = shuffle
  17. def __iter__(self):
  18. ids = np.argsort(self.lengths)
  19. if self.drop_last:
  20. ids = ids[:len(ids) // self.batch_size * self.batch_size]
  21. batches = [ids[i:i+self.batch_size] for i in range(0, len(ids), self.batch_size)]
  22. if self.shuffle:
  23. random.shuffle(batches)
  24. for b in batches:
  25. yield b
  26. def __len__(self):
  27. if self.drop_last:
  28. return len(self.lengths) // self.batch_size
  29. else:
  30. return len(self.lengths) // self.batch_size + (len(self.lengths) % self.batch_size > 0)
  31. class DistributedLengthBasedBatchSampler(torch.utils.data.BatchSampler):
  32. def __init__(self, data_source, batch_size: int, num_replicas: int, rank: int, shuffle: bool = True, seed: int = 0) -> None:
  33. random.seed(seed)
  34. self.batch_sampler = LengthBasedBatchSampler(
  35. data_source, batch_size=batch_size, drop_last=True, shuffle=shuffle
  36. )
  37. self.num_replicas = num_replicas
  38. self.rank = rank
  39. def __iter__(self):
  40. max_length = len(self.batch_sampler) // self.num_replicas * self.num_replicas
  41. return islice(self.batch_sampler, self.rank, max_length, self.num_replicas)
  42. def __len__(self):
  43. return len(self.batch_sampler) // self.num_replicas