# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
import random
from itertools import islice

import numpy as np
import torch
class LengthBasedBatchSampler(torch.utils.data.BatchSampler):
    """Batch sampler that groups samples of similar length into the same batch.

    Indices are sorted by sample length before being chunked into batches,
    which keeps per-batch padding small. When ``shuffle`` is on, the order
    of the *batches* is randomized (using the global ``random`` state), so
    epochs still vary while each batch stays length-homogeneous.
    """

    def __init__(self, data_source, batch_size: int, drop_last: bool, shuffle: bool = True) -> None:
        # Compute a length per sample. For dict samples, the first key's
        # field is used as the length proxy.
        first_sample = next(iter(data_source))
        if isinstance(first_sample, dict):
            probe_key = next(iter(first_sample.keys()))
            self.lengths = [len(sample[probe_key]) for sample in data_source]
        else:
            self.lengths = [len(sample) for sample in data_source]
        self.batch_size = batch_size
        self.drop_last = drop_last
        self.shuffle = shuffle

    def __iter__(self):
        # Indices ordered by ascending sample length.
        order = np.argsort(self.lengths)
        step = self.batch_size
        if self.drop_last:
            # Trim the tail so every batch has exactly batch_size indices.
            usable = len(order) // step * step
            order = order[:usable]
        batches = [order[start:start + step] for start in range(0, len(order), step)]
        if self.shuffle:
            random.shuffle(batches)
        yield from batches

    def __len__(self):
        n_full, remainder = divmod(len(self.lengths), self.batch_size)
        if self.drop_last:
            return n_full
        # Count a trailing partial batch when one exists.
        return n_full + (remainder > 0)
class DistributedLengthBasedBatchSampler(torch.utils.data.BatchSampler):
    """Distributed wrapper around :class:`LengthBasedBatchSampler`.

    Each replica takes every ``num_replicas``-th batch starting at its own
    ``rank``, and the batch stream is truncated so that all replicas yield
    exactly the same number of batches per epoch.
    """

    def __init__(self, data_source, batch_size: int, num_replicas: int, rank: int, shuffle: bool = True, seed: int = 0) -> None:
        # Seed the global RNG so the batch shuffle inside the wrapped
        # sampler is identical on every replica.
        random.seed(seed)
        self.batch_sampler = LengthBasedBatchSampler(
            data_source, batch_size=batch_size, drop_last=True, shuffle=shuffle
        )
        self.num_replicas = num_replicas
        self.rank = rank

    def __iter__(self):
        # Largest batch count divisible by the replica count — every rank
        # then sees len(batch_sampler) // num_replicas batches.
        total = len(self.batch_sampler)
        usable = total - total % self.num_replicas
        return islice(self.batch_sampler, self.rank, usable, self.num_replicas)

    def __len__(self):
        return len(self.batch_sampler) // self.num_replicas