
Adjust tests to length-based batch sampling

Matthias Reso, 1 year ago · commit ca41c1c697

+ 0 - 1
examples/custom_dataset.py

@@ -86,6 +86,5 @@ def get_custom_dataset(dataset_config, tokenizer, split):
             
     dataset = dataset.map(lambda x: to_dialog(x["thread"]), remove_columns=list(dataset.features))
     dataset = dataset.map(lambda x: tokenize_dialog(x["dialog"], tokenizer), remove_columns=list(dataset.features))
-    dataset = dataset.map(Concatenator(), batched=True)
     
     return dataset

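With the packing step gone, get_custom_dataset now returns one variable-length tokenized sample per dialog instead of fixed-size chunks. A rough sketch of the difference (illustrative only, not the repo's Concatenator from llama_recipes.data.concatenator, which may differ in detail):

```python
# Illustrative only: the removed packing step joined token ids and cut fixed-size
# chunks, so every batch element shared one length. After this commit the samples
# keep their natural lengths and are grouped and padded at batch time instead.
from itertools import chain

def pack(tokenized_samples, chunk_size=2048):
    ids = list(chain.from_iterable(s["input_ids"] for s in tokenized_samples))
    return [{"input_ids": ids[i:i + chunk_size]}
            for i in range(0, len(ids) - chunk_size + 1, chunk_size)]

samples = [{"input_ids": [1] * 12}, {"input_ids": [1] * 87}, {"input_ids": [1] * 40}]
print([len(s["input_ids"]) for s in samples])                        # [12, 87, 40]
print([len(s["input_ids"]) for s in pack(samples, chunk_size=64)])   # [64, 64]
```
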
+ 1 - 2
src/llama_recipes/datasets/samsum_dataset.py

@@ -27,7 +27,6 @@ def get_preprocessed_samsum(dataset_config, tokenizer, split):
         
     dataset = dataset.map(
         lambda sample: tokenizer(sample["text"]),
-        batched=True,
         remove_columns=list(dataset.features),
-    ).map(Concatenator(), batched=True)
+    )
     return dataset

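Dropping batched=True is part of the same change: map() now passes one example at a time, so each dataset row stays a single tokenized sample whose length the new batch sampler can inspect. A minimal sketch with a stand-in tokenizer (not the repo's LlamaTokenizer):

```python
# Stand-in tokenizer for illustration; the token count here is just the word count.
mock_tokenizer = lambda text: {"input_ids": list(range(len(text.split())))}

row = {"text": "A: Are we still on for tonight? B: Yes, see you at eight."}
tokenized = mock_tokenizer(row["text"])   # one example in, one sample out
print(len(tokenized["input_ids"]))        # the per-sample length used for batching
```
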
+ 6 - 27
src/llama_recipes/finetuning.py

@@ -6,7 +6,6 @@ from pkg_resources import packaging
 
 import fire
 import torch
-import torch.distributed as dist
 import torch.optim as optim
 from peft import get_peft_model, prepare_model_for_int8_training
 from torch.distributed.fsdp import (
@@ -14,13 +13,11 @@ from torch.distributed.fsdp import (
 )
 from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload
 from torch.optim.lr_scheduler import StepLR
-from torch.utils.data import DistributedSampler
 from transformers import (
     LlamaForCausalLM,
     LlamaTokenizer,
     LlamaConfig,
-    default_data_collator,
-)
+)   
 from transformers.models.llama.modeling_llama import LlamaDecoderLayer
 
 from llama_recipes.configs import fsdp_config, train_config
@@ -31,6 +28,7 @@ from llama_recipes.utils.config_utils import (
     update_config,
     generate_peft_config,
     generate_dataset_config,
+    get_sampler_kwargs,
 )
 from llama_recipes.utils.dataset_utils import get_preprocessed_dataset
 
@@ -179,43 +177,24 @@ def main(**kwargs):
     if not train_config.enable_fsdp or rank == 0:
             print(f"--> Validation Set Length = {len(dataset_val)}")
 
-    train_sampler = None
-    val_sampler = None
-    if train_config.enable_fsdp:
-        train_sampler = DistributedSampler(
-            dataset_train,
-            rank=dist.get_rank(),
-            num_replicas=dist.get_world_size(),
-            shuffle=True,
-        )
-        if train_config.run_validation:
-            val_sampler = DistributedSampler(
-                dataset_val,
-                rank=dist.get_rank(),
-                num_replicas=dist.get_world_size(),
-            )
+    train_dl_kwargs = get_sampler_kwargs(train_config, dataset_train, tokenizer, "train")
+    val_dl_kwargs = get_sampler_kwargs(train_config, dataset_val, tokenizer, "val")
 
     # Create DataLoaders for the training and validation dataset
     train_dataloader = torch.utils.data.DataLoader(
         dataset_train,
-        batch_size=train_config.batch_size_training,
         num_workers=train_config.num_workers_dataloader,
         pin_memory=True,
-        sampler=train_sampler if train_sampler else None,
-        drop_last=True,
-        collate_fn=default_data_collator,
+        **train_dl_kwargs,
     )
 
     eval_dataloader = None
     if train_config.run_validation:
         eval_dataloader = torch.utils.data.DataLoader(
             dataset_val,
-            batch_size=train_config.val_batch_size,
             num_workers=train_config.num_workers_dataloader,
             pin_memory=True,
-            sampler=val_sampler if val_sampler else None,
-            drop_last=True,
-            collate_fn=default_data_collator,
+            **val_dl_kwargs,
         )
 
     # Initialize the optimizer and learning rate scheduler

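The DataLoader calls no longer set batch_size, sampler, drop_last, or collate_fn directly because those now arrive through the kwargs dicts, and PyTorch forbids combining them with a batch_sampler. A small sketch of that constraint in plain PyTorch (not repo code):

```python
import torch
from torch.utils.data import DataLoader

data = [torch.tensor([i]) for i in range(10)]
batches = [[0, 1, 2], [3, 4, 5], [6, 7, 8]]   # pre-formed index batches

# With batch_sampler, the loader yields exactly these batches...
dl = DataLoader(data, batch_sampler=batches)
print([b.squeeze(1).tolist() for b in dl])    # [[0, 1, 2], [3, 4, 5], [6, 7, 8]]

# ...while also passing batch_size, shuffle, sampler, or drop_last raises a
# ValueError, hence the switch to **train_dl_kwargs / **val_dl_kwargs.
```
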
+ 28 - 1
src/llama_recipes/utils/config_utils.py

@@ -3,13 +3,19 @@
 
 import inspect
 from dataclasses import asdict
+
+import torch.distributed as dist
+from torch.utils.data import DistributedSampler
 from peft import (
     LoraConfig,
     AdaptionPromptConfig,
     PrefixTuningConfig,
 )
+from transformers import default_data_collator
+from transformers.data import DataCollatorWithPadding
 
 from llama_recipes.configs import datasets, lora_config, llama_adapter_config, prefix_config, train_config
+from llama_recipes.data.sampler import LengthBasedBatchSampler
 from llama_recipes.utils.dataset_utils import DATASET_PREPROC
 
 
@@ -59,4 +65,25 @@ def generate_dataset_config(train_config, kwargs):
         
     update_config(dataset_config, **kwargs)
     
-    return  dataset_config
+    return  dataset_config
+
+
+def get_sampler_kwargs(train_config, dataset, tokenizer, mode):
+        kwargs = {}
+        batch_size = train_config.batch_size_training if mode=="train" else train_config.val_batch_size
+        if train_config.enable_fsdp:
+            sampler = DistributedSampler(
+                dataset,
+                rank=dist.get_rank(),
+                num_replicas=dist.get_world_size(),
+                shuffle=mode=="train",
+            )
+            kwargs["sampler"] = sampler
+            kwargs["batch_size"] = batch_size
+            kwargs["drop_last"] = True
+            kwargs["collate_fn"] = default_data_collator
+        else:
+            kwargs["batch_sampler"] = LengthBasedBatchSampler(dataset, batch_size, drop_last=True, randomize=mode=="train")
+            kwargs["collate_fn"] = DataCollatorWithPadding(tokenizer)
+            
+        return kwargs

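LengthBasedBatchSampler is imported from llama_recipes.data.sampler but not shown in this commit. A minimal sketch of a batch sampler with this call signature (the real implementation may differ, for instance in how it measures sample length or shuffles):

```python
import random

class LengthBasedBatchSamplerSketch:
    """Group samples of similar length so padding inside each batch stays small."""

    def __init__(self, data_source, batch_size, drop_last=True, randomize=True):
        # Assumes dict samples with "input_ids", as produced by the tokenizer maps above.
        self.lengths = [len(s["input_ids"]) for s in data_source]
        self.batch_size = batch_size
        self.drop_last = drop_last
        self.randomize = randomize

    def __iter__(self):
        order = sorted(range(len(self.lengths)), key=self.lengths.__getitem__)
        batches = [order[i:i + self.batch_size]
                   for i in range(0, len(order), self.batch_size)]
        if self.drop_last and batches and len(batches[-1]) < self.batch_size:
            batches = batches[:-1]
        if self.randomize:
            random.shuffle(batches)   # similar lengths stay together; batch order varies
        yield from batches

    def __len__(self):
        n, r = divmod(len(self.lengths), self.batch_size)
        return n if self.drop_last else n + (1 if r else 0)
```

Paired with DataCollatorWithPadding, each batch is padded only up to its own longest sample, which is what the adjusted test expectations below count on.
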
+ 9 - 11
tests/datasets/test_custom_dataset.py

@@ -18,6 +18,7 @@ def test_custom_dataset(step_lr, optimizer, get_model, train, mocker):
         "custom_dataset.file": "examples/custom_dataset.py",
         "custom_dataset.train_split": "validation",
         "batch_size_training": 2,
+        "val_batch_size": 4,
         "use_peft": False,
         }
 
@@ -30,24 +31,21 @@ def test_custom_dataset(step_lr, optimizer, get_model, train, mocker):
     eval_dataloader = args[2]
     tokenizer = args[3]
 
-    assert len(train_dataloader) == 226
-    assert len(eval_dataloader) == 2*226
+    assert len(train_dataloader) == 1120
+    assert len(eval_dataloader) == 1120 //2
 
-    it = iter(train_dataloader)
+    it = iter(eval_dataloader)
     STRING = tokenizer.decode(next(it)["input_ids"][0], skip_special_tokens=True)
-    EXPECTED_STRING = "[INST] Напиши функцию на языке swift, которая сортирует массив целых чисел, а затем выводит его на экран [/INST] Вот функция, "
-
+    EXPECTED_STRING = "[INST] Who made Berlin [/INST] dunno"
     assert STRING.startswith(EXPECTED_STRING)
+    
+    assert next(it)["input_ids"].size(0) == 4
 
-    next(it)
     next(it)
     next(it)
     STRING = tokenizer.decode(next(it)["input_ids"][0], skip_special_tokens=True)
-    EXPECTED_SUBSTRING_1 = "Therefore you are correct.  [INST] How can L’Hopital’s Rule be"
-    EXPECTED_SUBSTRING_2 = "a circular path around the turn.  [INST] How on earth is that related to L’Hopital’s Rule?"
-
-    assert EXPECTED_SUBSTRING_1 in STRING
-    assert EXPECTED_SUBSTRING_2 in STRING
+    EXPECTED_STRING = "[INST] Implementa el algoritmo `bubble sort` en C. [/INST] xdxdxd"
+    assert STRING.startswith(EXPECTED_STRING)
 
 
 @patch('llama_recipes.finetuning.train')

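The new expectations are mutually consistent: 1120 train batches of size 2 and 1120 // 2 = 560 eval batches of size 4 both cover 2240 samples, which fits train and eval reading splits of equal size here (train_split is overridden to "validation"). A quick check:

```python
# Consistency check of the expected dataloader lengths and batch sizes above.
train_batches, train_bs = 1120, 2
eval_batches, eval_bs = 1120 // 2, 4
assert train_batches * train_bs == eval_batches * eval_bs == 2240
```
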
+ 6 - 5
tests/datasets/test_samsum_datasets.py

@@ -9,14 +9,15 @@ from unittest.mock import patch
 @patch('llama_recipes.finetuning.LlamaTokenizer.from_pretrained')
 @patch('llama_recipes.finetuning.optim.AdamW')
 @patch('llama_recipes.finetuning.StepLR')
-def test_custom_dataset(step_lr, optimizer, tokenizer, get_model, train, mocker):
+def test_samsum_dataset(step_lr, optimizer, tokenizer, get_model, train, mocker):
     from llama_recipes.finetuning import main
         
     tokenizer.return_value = mocker.MagicMock(side_effect=lambda x: {"input_ids":[len(x)*[0,]], "attention_mask": [len(x)*[0,]]})
     
-    
+    BATCH_SIZE = 8
     kwargs = {
-        "batch_size_training": 1,
+        "batch_size_training": 8,
+        "val_batch_size": 1,
         "use_peft": False,
         "dataset": "samsum_dataset",
         }
@@ -31,7 +32,7 @@ def test_custom_dataset(step_lr, optimizer, tokenizer, get_model, train, mocker)
     
     VAL_SAMPLES = 818
     TRAIN_SAMPLES = 14732
-    CONCAT_SIZE = 2048
-    assert len(train_dataloader) == TRAIN_SAMPLES // CONCAT_SIZE
+    
+    assert len(train_dataloader) == TRAIN_SAMPLES // BATCH_SIZE
     assert len(eval_dataloader) == VAL_SAMPLES
     

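The expected lengths follow directly from the batch sizes: with the incomplete last batch dropped on the training side and val_batch_size set to 1, the counts are:

```python
TRAIN_SAMPLES, VAL_SAMPLES, BATCH_SIZE = 14732, 818, 8
assert TRAIN_SAMPLES // BATCH_SIZE == 1841   # train batches, last partial batch dropped
assert VAL_SAMPLES // 1 == 818               # val_batch_size is 1 in this test
```
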
+ 4 - 4
tests/test_finetuning.py

@@ -19,7 +19,7 @@ from llama_recipes.finetuning import main
 def test_finetuning_no_validation(step_lr, optimizer, get_dataset, tokenizer, get_model, train):
     kwargs = {"run_validation": False}
     
-    get_dataset.return_value = [1]
+    get_dataset.return_value = [[1]]
     
     main(**kwargs)
     
@@ -43,7 +43,7 @@ def test_finetuning_no_validation(step_lr, optimizer, get_dataset, tokenizer, ge
 @patch('llama_recipes.finetuning.StepLR')
 def test_finetuning_with_validation(step_lr, optimizer, get_dataset, tokenizer, get_model, train):
     kwargs = {"run_validation": True}
-    get_dataset.return_value = [1]
+    get_dataset.return_value = [[1]]
     
     main(**kwargs)
     
@@ -69,7 +69,7 @@ def test_finetuning_with_validation(step_lr, optimizer, get_dataset, tokenizer,
 def test_finetuning_peft(step_lr, optimizer, get_peft_model, gen_peft_config, get_dataset, tokenizer, get_model, train):
     kwargs = {"use_peft": True}
     
-    get_dataset.return_value = [1]
+    get_dataset.return_value = [[1]]
     
     main(**kwargs)
     
@@ -86,7 +86,7 @@ def test_finetuning_peft(step_lr, optimizer, get_peft_model, gen_peft_config, ge
 def test_finetuning_weight_decay(step_lr, get_peft_model, get_dataset, tokenizer, get_model, train):
     kwargs = {"weight_decay": 0.01}
     
-    get_dataset.return_value = [1]
+    get_dataset.return_value = [[1]]
     
     get_peft_model.return_value = Linear(1,1)
     get_peft_model.return_value.print_trainable_parameters=lambda:None
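The mocked dataset changes from [1] to [[1]] presumably so that every sample has a length for the batch sampler to inspect; an int has no len(). A tiny illustration (mirrors the mock only, not repo code):

```python
samples_old = [1]     # len(1) raises TypeError when a sampler asks for sample lengths
samples_new = [[1]]   # len([1]) == 1, so length-based batching works on the mock
assert [len(s) for s in samples_new] == [1]
```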