Browse Source

Adds length based batch sampler

Matthias Reso 1 year ago
parent
commit
f620f3589d

+ 36 - 1
src/llama_recipes/datasets/utils.py

@@ -1,10 +1,13 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
 
+import random
 from tqdm import tqdm
 from itertools import chain
 
-from torch.utils.data import Dataset
+import numpy as np
+from torch.utils.data import Dataset, BatchSampler
+
 
 class Concatenator(object):
     def __init__(self, chunk_size=2048):
@@ -64,3 +67,35 @@ class ConcatDataset(Dataset):
     
     def __len__(self):
         return len(self.samples)
+    
+
+class LengthBasedBatchSampler(BatchSampler):
+    def __init__(self, data_source, batch_size, drop_last, randomize=True):
+        if isinstance(next(iter(data_source)), dict):
+            first_key = next(iter(next(iter(data_source)).keys()))
+            self.lengths = [len(d[first_key]) for d in data_source]
+        else:
+            self.lengths = [len(d) for d in data_source]
+        self.batch_size = batch_size
+        self.drop_last = drop_last
+        self.randomize = randomize
+        
+    def __iter__(self):
+        ids = np.argsort(self.lengths)
+        if self.drop_last:
+            ids = ids[:len(ids) // self.batch_size * self.batch_size]
+            
+        batches = [ids[i:i+self.batch_size] for i in range(0, len(ids), self.batch_size)]
+        
+        if self.randomize:
+            random.shuffle(batches)
+            
+        for b in batches:
+            yield b
+        
+    def __len__(self):
+        if self.drop_last:
+            return len(self.lengths) // self.batch_size 
+        else:
+            return len(self.lengths) // self.batch_size + (len(self.lengths) % self.batch_size > 0)
+        

+ 53 - 0
tests/datasets/test_length_based_batch_sampler.py

@@ -0,0 +1,53 @@
+import random
+import pytest
+
+import torch
+
+from llama_recipes.datasets.utils import LengthBasedBatchSampler
+
+SAMPLES = 33
+
+@pytest.fixture
+def dataset():
+    dataset = []
+    def add_samples(ds, n, a, b):
+        for _ in range(n):
+            ds.append(random.randint(a,b) * [1,])
+    add_samples(dataset, SAMPLES // 2, 1,9)
+    add_samples(dataset, (SAMPLES // 2) + (SAMPLES % 2), 10,20)
+    
+    return random.sample(dataset, len(dataset))
+    
+    
+@pytest.mark.parametrize("batch_size, drop_last", [(2, False), (8, False), (2, True), (8, True)])
+def test_batch_sampler_array(dataset, batch_size, drop_last):
+    
+    sampler = LengthBasedBatchSampler(dataset, batch_size, drop_last)
+    
+    EXPECTED_LENGTH = SAMPLES // batch_size if drop_last else (SAMPLES // batch_size) + (SAMPLES % batch_size)
+    
+    assert len(sampler) == EXPECTED_LENGTH
+    is_long = [len(d)>=10 for d in dataset]
+    
+    def check_batch(batch):
+        return all(batch) or not any(batch)
+    
+    assert all(check_batch(is_long[i] for i in b) for b in sampler)
+    
+    
+@pytest.mark.parametrize("batch_size, drop_last", [(2, False), (8, False), (2, True), (8, True)])
+def test_batch_sampler_dict(dataset, batch_size, drop_last):
+    
+    dist_dataset = [{"input_ids": d, "attention_mask": d} for d in dataset]
+    
+    sampler = LengthBasedBatchSampler(dist_dataset, batch_size, drop_last)
+    
+    EXPECTED_LENGTH = SAMPLES // batch_size if drop_last else (SAMPLES // batch_size) + (SAMPLES % batch_size)
+    
+    assert len(sampler) == EXPECTED_LENGTH
+    is_long = [len(d)>=10 for d in dataset]
+    
+    def check_batch(batch):
+        return all(batch) or not any(batch)
+    
+    assert all(check_batch(is_long[i] for i in b) for b in sampler)