Moved sampler to data submodule

Matthias Reso, 1 year ago
commit 63ce4ce7f6
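
This commit extracts the LengthBasedBatchSampler class from src/llama_recipes/datasets/utils.py into a new src/llama_recipes/data/sampler.py module, adds the data/__init__.py needed to make the submodule importable, and updates the test import accordingly.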

+ 2 - 0
src/llama_recipes/data/__init__.py

@@ -0,0 +1,2 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

+ 38 - 0
src/llama_recipes/data/sampler.py

@@ -0,0 +1,38 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
+
+import random
+
+import numpy as np
+import torch
+
+
+class LengthBasedBatchSampler(torch.utils.data.BatchSampler):
+    def __init__(self, data_source, batch_size, drop_last, randomize=True):
+        if isinstance(next(iter(data_source)), dict):
+            first_key = next(iter(next(iter(data_source)).keys()))
+            self.lengths = [len(d[first_key]) for d in data_source]
+        else:
+            self.lengths = [len(d) for d in data_source]
+        self.batch_size = batch_size
+        self.drop_last = drop_last
+        self.randomize = randomize
+
+    def __iter__(self):
+        ids = np.argsort(self.lengths)
+        if self.drop_last:
+            ids = ids[:len(ids) // self.batch_size * self.batch_size]
+
+        batches = [ids[i:i+self.batch_size] for i in range(0, len(ids), self.batch_size)]
+
+        if self.randomize:
+            random.shuffle(batches)
+
+        for b in batches:
+            yield b
+
+    def __len__(self):
+        if self.drop_last:
+            return len(self.lengths) // self.batch_size
+        else:
+            return len(self.lengths) // self.batch_size + (len(self.lengths) % self.batch_size > 0)
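
The sampler sorts sample indices by length so each batch groups sequences of similar size, which keeps per-batch padding small; shuffling whole batches rather than individual indices preserves that grouping while still randomizing batch order across epochs. For `__len__`, 33 samples with a batch size of 4 give 8 batches when `drop_last=True` (the leftover sample is discarded) and 9 otherwise. A minimal usage sketch, not part of this commit; the toy dataset and collate function are illustrative stand-ins:

```python
# Illustrative usage (not part of this commit): drive a DataLoader with
# LengthBasedBatchSampler. The toy dataset and collate function are
# hypothetical stand-ins for a real tokenized dataset.
import torch
from torch.utils.data import DataLoader

from llama_recipes.data.sampler import LengthBasedBatchSampler

# Toy dataset: variable-length sequences of token ids.
data = [torch.arange(n) for n in (5, 3, 9, 2, 7, 4, 8, 6)]

sampler = LengthBasedBatchSampler(data, batch_size=2, drop_last=True, randomize=True)

def collate(batch):
    # Pad each batch only up to its own longest sequence.
    return torch.nn.utils.rnn.pad_sequence(batch, batch_first=True)

loader = DataLoader(data, batch_sampler=sampler, collate_fn=collate)
for batch in loader:
    # Because batches group similar lengths, padding per batch stays small.
    print(batch.shape)
```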

+ 1 - 33
src/llama_recipes/datasets/utils.py

@@ -1,12 +1,10 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
 
-import random
 from tqdm import tqdm
 from itertools import chain
 
-import numpy as np
-from torch.utils.data import Dataset, BatchSampler
+from torch.utils.data import Dataset
 
 
 class Concatenator(object):
@@ -69,33 +67,3 @@ class ConcatDataset(Dataset):
         return len(self.samples)
     
 
-class LengthBasedBatchSampler(BatchSampler):
-    def __init__(self, data_source, batch_size, drop_last, randomize=True):
-        if isinstance(next(iter(data_source)), dict):
-            first_key = next(iter(next(iter(data_source)).keys()))
-            self.lengths = [len(d[first_key]) for d in data_source]
-        else:
-            self.lengths = [len(d) for d in data_source]
-        self.batch_size = batch_size
-        self.drop_last = drop_last
-        self.randomize = randomize
-        
-    def __iter__(self):
-        ids = np.argsort(self.lengths)
-        if self.drop_last:
-            ids = ids[:len(ids) // self.batch_size * self.batch_size]
-            
-        batches = [ids[i:i+self.batch_size] for i in range(0, len(ids), self.batch_size)]
-        
-        if self.randomize:
-            random.shuffle(batches)
-            
-        for b in batches:
-            yield b
-        
-    def __len__(self):
-        if self.drop_last:
-            return len(self.lengths) // self.batch_size 
-        else:
-            return len(self.lengths) // self.batch_size + (len(self.lengths) % self.batch_size > 0)
-        

+ 1 - 1
tests/datasets/test_length_based_batch_sampler.py

@@ -3,7 +3,7 @@ import pytest
 
 import torch
 
-from llama_recipes.datasets.utils import LengthBasedBatchSampler
+from llama_recipes.data.sampler import LengthBasedBatchSampler
 
 SAMPLES = 33
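
Only the import line changes here; the rest of the test file is unchanged and not shown. For illustration, a hedged sketch of the kind of invariant such a test can check against the relocated class (`BATCH_SIZE` is an assumed name; only `SAMPLES = 33` appears in the diff):

```python
# Hedged sketch of a sanity test against the new import path; the actual
# test body in this file is not shown in the diff above.
import torch

from llama_recipes.data.sampler import LengthBasedBatchSampler

SAMPLES = 33
BATCH_SIZE = 4  # hypothetical; only SAMPLES = 33 appears in the diff

def test_len_matches_iteration():
    data = [torch.randn(i % 10 + 1) for i in range(SAMPLES)]
    for drop_last in (True, False):
        sampler = LengthBasedBatchSampler(data, BATCH_SIZE, drop_last=drop_last)
        batches = list(sampler)
        # __len__ must agree with the number of batches actually yielded:
        # 33 // 4 = 8 with drop_last, 8 + 1 = 9 without.
        assert len(batches) == len(sampler)
        if drop_last:
            assert all(len(b) == BATCH_SIZE for b in batches)
```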