
Fix order of concat vs sampler

Matthias Reso, 1 year ago
commit 4c225c65eb
1 file changed, 5 additions and 3 deletions

src/llama_recipes/finetuning.py  +5 -3

@@ -179,12 +179,11 @@ def main(**kwargs):
     if not train_config.enable_fsdp or rank == 0:
             print(f"--> Validation Set Length = {len(dataset_val)}")
 
-    train_dl_kwargs = get_dataloader_kwargs(train_config, dataset_train, tokenizer, "train")
-    val_dl_kwargs = get_dataloader_kwargs(train_config, dataset_val, tokenizer, "val")
-
     if train_config.batching_strategy == "packing":
         dataset_train = ConcatDataset(dataset_train, chunk_size=train_config.context_length)
 
+    train_dl_kwargs = get_dataloader_kwargs(train_config, dataset_train, tokenizer, "train")
+
     # Create DataLoaders for the training and validation dataset
     train_dataloader = torch.utils.data.DataLoader(
         dataset_train,
@@ -197,6 +196,9 @@ def main(**kwargs):
     if train_config.run_validation:
         if train_config.batching_strategy == "packing":
             dataset_val = ConcatDataset(dataset_val, chunk_size=train_config.context_length)
+
+        val_dl_kwargs = get_dataloader_kwargs(train_config, dataset_val, tokenizer, "val")
+
         eval_dataloader = torch.utils.data.DataLoader(
             dataset_val,
             num_workers=train_config.num_workers_dataloader,
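
Why the order matters: the kwargs returned by get_dataloader_kwargs include the sampler, and the sampler is sized from the dataset it is given. If the kwargs are built before the dataset is wrapped in ConcatDataset for packing, the sampler still indexes the much larger un-packed dataset. The sketch below illustrates this with a toy packing dataset and a hypothetical stand-in for get_dataloader_kwargs; the names and simplified behavior are assumptions for illustration only, not the recipe's actual implementation.

```python
from torch.utils.data import Dataset, DistributedSampler


class PackedDatasetSketch(Dataset):
    """Toy stand-in for ConcatDataset: packing shrinks the dataset length."""
    def __init__(self, dataset, chunk_size):
        # The real ConcatDataset concatenates tokens and splits them into
        # chunks; here we only model the resulting (smaller) length.
        self.length = max(1, len(dataset) // chunk_size)

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        return idx


def dl_kwargs_sketch(dataset, rank=0, world_size=1):
    # Hypothetical simplification of get_dataloader_kwargs: the sampler is
    # constructed from len(dataset) at call time.
    return {"sampler": DistributedSampler(dataset, num_replicas=world_size, rank=rank)}


raw = list(range(8192))  # un-packed dataset with 8192 samples

# Pre-fix order: sampler covers 8192 indices, but packing leaves only 2 chunks,
# so the sampler would yield out-of-range indices for the packed dataset.
kwargs = dl_kwargs_sketch(raw)
packed = PackedDatasetSketch(raw, chunk_size=4096)
print(len(packed), len(kwargs["sampler"]))   # 2 vs 8192 -> mismatch

# Post-fix order (this commit): pack first, then build the dataloader kwargs.
packed = PackedDatasetSketch(raw, chunk_size=4096)
kwargs = dl_kwargs_sketch(packed)
print(len(packed), len(kwargs["sampler"]))   # 2 vs 2 -> consistent
```

The same reasoning applies to the validation split, which is why val_dl_kwargs is now computed inside the run_validation branch, after dataset_val has been wrapped.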