|
@@ -1,12 +1,10 @@
|
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
|
|
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
|
|
|
|
|
|
-import random
|
|
|
|
from tqdm import tqdm
|
|
from tqdm import tqdm
|
|
from itertools import chain
|
|
from itertools import chain
|
|
|
|
|
|
-import numpy as np
|
|
|
|
-from torch.utils.data import Dataset, BatchSampler
|
|
|
|
|
|
+from torch.utils.data import Dataset
|
|
|
|
|
|
|
|
|
|
class Concatenator(object):
|
|
class Concatenator(object):
|
|
@@ -69,33 +67,3 @@ class ConcatDataset(Dataset):
|
|
return len(self.samples)
|
|
return len(self.samples)
|
|
|
|
|
|
|
|
|
|
-class LengthBasedBatchSampler(BatchSampler):
|
|
|
|
- def __init__(self, data_source, batch_size, drop_last, randomize=True):
|
|
|
|
- if isinstance(next(iter(data_source)), dict):
|
|
|
|
- first_key = next(iter(next(iter(data_source)).keys()))
|
|
|
|
- self.lengths = [len(d[first_key]) for d in data_source]
|
|
|
|
- else:
|
|
|
|
- self.lengths = [len(d) for d in data_source]
|
|
|
|
- self.batch_size = batch_size
|
|
|
|
- self.drop_last = drop_last
|
|
|
|
- self.randomize = randomize
|
|
|
|
-
|
|
|
|
- def __iter__(self):
|
|
|
|
- ids = np.argsort(self.lengths)
|
|
|
|
- if self.drop_last:
|
|
|
|
- ids = ids[:len(ids) // self.batch_size * self.batch_size]
|
|
|
|
-
|
|
|
|
- batches = [ids[i:i+self.batch_size] for i in range(0, len(ids), self.batch_size)]
|
|
|
|
-
|
|
|
|
- if self.randomize:
|
|
|
|
- random.shuffle(batches)
|
|
|
|
-
|
|
|
|
- for b in batches:
|
|
|
|
- yield b
|
|
|
|
-
|
|
|
|
- def __len__(self):
|
|
|
|
- if self.drop_last:
|
|
|
|
- return len(self.lengths) // self.batch_size
|
|
|
|
- else:
|
|
|
|
- return len(self.lengths) // self.batch_size + (len(self.lengths) % self.batch_size > 0)
|
|
|
|
-
|
|
|