
Added distributed length-based batch sampler

Matthias Reso · 1 year ago · commit ddf58d205d

src/llama_recipes/data/sampler.py  (+22 -3)

@@ -2,13 +2,14 @@
 # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
 
 import random
+from itertools import islice
 
 import numpy as np
 import torch
 
 
 class LengthBasedBatchSampler(torch.utils.data.BatchSampler):
-    def __init__(self, data_source, batch_size, drop_last, randomize=True):
+    def __init__(self, data_source, batch_size: int, drop_last: bool, shuffle: bool=True) -> None:
         if isinstance(next(iter(data_source)), dict):
             first_key = next(iter(next(iter(data_source)).keys()))
             self.lengths = [len(d[first_key]) for d in data_source]
@@ -16,7 +17,7 @@ class LengthBasedBatchSampler(torch.utils.data.BatchSampler):
             self.lengths = [len(d) for d in data_source]
         self.batch_size = batch_size
         self.drop_last = drop_last
-        self.randomize = randomize
+        self.shuffle = shuffle
 
     def __iter__(self):
         ids = np.argsort(self.lengths)
@@ -25,7 +26,7 @@ class LengthBasedBatchSampler(torch.utils.data.BatchSampler):
 
         batches = [ids[i:i+self.batch_size] for i in range(0, len(ids), self.batch_size)]
 
-        if self.randomize:
+        if self.shuffle:
             random.shuffle(batches)
 
         for b in batches:
@@ -36,3 +37,21 @@ class LengthBasedBatchSampler(torch.utils.data.BatchSampler):
             return len(self.lengths) // self.batch_size
         else:
             return len(self.lengths) // self.batch_size + (len(self.lengths) % self.batch_size > 0)
+
+
+class DistributedLengthBasedBatchSampler(torch.utils.data.BatchSampler):
+    def __init__(self, data_source, batch_size: int, num_replicas: int, rank: int, shuffle: bool = True, seed: int = 0) -> None:
+        random.seed(seed)
+        self.batch_sampler = LengthBasedBatchSampler(
+            data_source, batch_size=batch_size, drop_last=True, shuffle=shuffle
+            )
+        self.num_replicas = num_replicas
+        self.rank = rank
+        
+    def __iter__(self):
+        max_length = len(self.batch_sampler) // self.num_replicas * self.num_replicas
+        return islice(self.batch_sampler, self.rank, max_length, self.num_replicas)
+         
+    def __len__(self):
+        return len(self.batch_sampler) // self.num_replicas
+            
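
For reference, a minimal usage sketch of the new sampler (the toy data and the two-rank loop are illustrative, not part of this commit): every rank builds the same length-sorted batch list, and islice then hands rank r batches r, r + num_replicas, r + 2*num_replicas, ... up to a common cutoff, so all ranks iterate over the same number of batches. With shuffle=True, the shared seed is presumably what keeps random.shuffle aligned across ranks before the slicing happens.

from llama_recipes.data.sampler import DistributedLengthBasedBatchSampler

# Toy dataset of token lists with varying lengths (illustrative only).
data = [[1] * n for n in (3, 15, 4, 12, 2, 18, 5, 11)]

# Both ranks see the same length-sorted batches; each keeps every second
# one starting at its own rank, so the shards are disjoint and equal-sized.
for rank in range(2):
    sampler = DistributedLengthBasedBatchSampler(
        data, batch_size=2, num_replicas=2, rank=rank, shuffle=False, seed=0
    )
    print(f"rank {rank}:", [list(batch) for batch in sampler])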

src/llama_recipes/utils/config_utils.py  (+5 -8)

@@ -15,7 +15,7 @@ from transformers import default_data_collator
 from transformers.data import DataCollatorWithPadding
 
 from llama_recipes.configs import datasets, lora_config, llama_adapter_config, prefix_config, train_config
-from llama_recipes.data.sampler import LengthBasedBatchSampler
+from llama_recipes.data.sampler import LengthBasedBatchSampler, DistributedLengthBasedBatchSampler
 from llama_recipes.utils.dataset_utils import DATASET_PREPROC
 
 
@@ -72,18 +72,15 @@ def get_sampler_kwargs(train_config, dataset, tokenizer, mode):
         kwargs = {}
         batch_size = train_config.batch_size_training if mode=="train" else train_config.val_batch_size
         if train_config.enable_fsdp:
-            sampler = DistributedSampler(
+            kwargs["batch_sampler"] = DistributedLengthBasedBatchSampler(
                 dataset,
+                batch_size=batch_size,
                 rank=dist.get_rank(),
                 num_replicas=dist.get_world_size(),
                 shuffle=mode=="train",
             )
-            kwargs["sampler"] = sampler
-            kwargs["batch_size"] = batch_size
-            kwargs["drop_last"] = True
-            kwargs["collate_fn"] = default_data_collator
         else:
-            kwargs["batch_sampler"] = LengthBasedBatchSampler(dataset, batch_size, drop_last=True, randomize=mode=="train")
-            kwargs["collate_fn"] = DataCollatorWithPadding(tokenizer)
+            kwargs["batch_sampler"] = LengthBasedBatchSampler(dataset, batch_size, drop_last=True, shuffle=mode=="train")
+        kwargs["collate_fn"] = DataCollatorWithPadding(tokenizer)
             
         return kwargs
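
A minimal sketch of how the returned kwargs are meant to be consumed (the DataLoader call below is illustrative, not part of this commit; train_dataset is an assumed preprocessed dataset): since the dict now always carries a batch_sampler, the caller must not pass batch_size, shuffle, sampler, or drop_last alongside it.

from torch.utils.data import DataLoader

# Illustrative wiring only: batch_sampler and collate_fn come from the
# kwargs built above; mutually exclusive DataLoader arguments such as
# batch_size, shuffle and drop_last are deliberately left out.
dl_kwargs = get_sampler_kwargs(train_config, train_dataset, tokenizer, mode="train")
train_dataloader = DataLoader(train_dataset, num_workers=1, pin_memory=True, **dl_kwargs)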

tests/test_length_based_batch_sampler.py  (+0 -57)

@@ -1,57 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
-
-import random
-import pytest
-
-import torch
-
-from llama_recipes.data.sampler import LengthBasedBatchSampler
-
-SAMPLES = 33
-
-@pytest.fixture
-def dataset():
-    random.seed(42)
-    dataset = []
-    def add_samples(ds, n, a, b):
-        for _ in range(n):
-            ds.append(random.randint(a,b) * [1,])
-    add_samples(dataset, SAMPLES // 2, 1,9)
-    add_samples(dataset, (SAMPLES // 2) + (SAMPLES % 2), 10,20)
-    
-    return random.sample(dataset, len(dataset))
-    
-    
-@pytest.mark.parametrize("batch_size, drop_last", [(2, False), (8, False), (2, True), (8, True)])
-def test_batch_sampler_array(dataset, batch_size, drop_last):
-    
-    sampler = LengthBasedBatchSampler(dataset, batch_size, drop_last)
-    
-    EXPECTED_LENGTH = SAMPLES // batch_size if drop_last else (SAMPLES // batch_size) + (SAMPLES % batch_size)
-    
-    assert len(sampler) == EXPECTED_LENGTH
-    is_long = [len(d)>=10 for d in dataset]
-    
-    def check_batch(batch):
-        return all(batch) or not any(batch)
-    
-    assert all(check_batch(is_long[i] for i in b) for b in sampler)
-    
-    
-@pytest.mark.parametrize("batch_size, drop_last", [(2, False), (8, False), (2, True), (8, True)])
-def test_batch_sampler_dict(dataset, batch_size, drop_last):
-    
-    dist_dataset = [{"input_ids": d, "attention_mask": d} for d in dataset]
-    
-    sampler = LengthBasedBatchSampler(dist_dataset, batch_size, drop_last)
-    
-    EXPECTED_LENGTH = SAMPLES // batch_size if drop_last else (SAMPLES // batch_size) + (SAMPLES % batch_size)
-    
-    assert len(sampler) == EXPECTED_LENGTH
-    is_long = [len(d)>=10 for d in dataset]
-    
-    def check_batch(batch):
-        return all(batch) or not any(batch)
-    
-    assert all(check_batch(is_long[i] for i in b) for b in sampler)
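
The single-process sampler tests are removed here without a distributed counterpart. Purely as an illustration (not part of this commit, and assuming a dataset fixture of variable-length samples like the one deleted above), a follow-up test could check that the new wrapper hands each rank a disjoint, equally sized share of the batches:

import pytest

from llama_recipes.data.sampler import DistributedLengthBasedBatchSampler

# Illustrative sketch only: "dataset" is assumed to be a fixture of
# variable-length samples, similar to the one removed in this commit.
@pytest.mark.parametrize("batch_size", [2, 8])
def test_dist_batch_sampler_partitions(dataset, batch_size):
    world_size = 2
    samplers = [
        DistributedLengthBasedBatchSampler(
            dataset, batch_size=batch_size, num_replicas=world_size, rank=r, shuffle=False
        )
        for r in range(world_size)
    ]

    # Every rank reports the same number of batches ...
    assert len(samplers[0]) == len(samplers[1])

    # ... and no sample index is handed to more than one rank.
    seen = [set(int(i) for batch in s for i in batch) for s in samplers]
    assert seen[0].isdisjoint(seen[1])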