
Fix order of concat vs sampler

Matthias Reso, 1 year ago
Commit 4c225c65eb
1 file changed, 5 insertions(+), 3 deletions(-)

src/llama_recipes/finetuning.py (+5, -3)

@@ -179,12 +179,11 @@ def main(**kwargs):
     if not train_config.enable_fsdp or rank == 0:
             print(f"--> Validation Set Length = {len(dataset_val)}")
 
-    train_dl_kwargs = get_dataloader_kwargs(train_config, dataset_train, tokenizer, "train")
-    val_dl_kwargs = get_dataloader_kwargs(train_config, dataset_val, tokenizer, "val")
-
     if train_config.batching_strategy == "packing":
         dataset_train = ConcatDataset(dataset_train, chunk_size=train_config.context_length)
 
+    train_dl_kwargs = get_dataloader_kwargs(train_config, dataset_train, tokenizer, "train")
+
     # Create DataLoaders for the training and validation dataset
     train_dataloader = torch.utils.data.DataLoader(
         dataset_train,
@@ -197,6 +196,9 @@ def main(**kwargs):
     if train_config.run_validation:
         if train_config.batching_strategy == "packing":
             dataset_val = ConcatDataset(dataset_val, chunk_size=train_config.context_length)
+
+        val_dl_kwargs = get_dataloader_kwargs(train_config, dataset_val, tokenizer, "val")
+
         eval_dataloader = torch.utils.data.DataLoader(
             dataset_val,
             num_workers=train_config.num_workers_dataloader,
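
The reason for the reorder: get_dataloader_kwargs builds the sampler from the dataset it is given, and a sampler captures the dataset's length at construction time. With the "packing" batching strategy, ConcatDataset collapses the raw samples into fewer fixed-size chunks, so computing the dataloader kwargs before wrapping leaves the sampler sized for the unpacked dataset. The following is a minimal, self-contained sketch of that failure mode; it is not the llama-recipes code. ToyPacked is a hypothetical stand-in for ConcatDataset, and DistributedSampler stands in for whatever sampler get_dataloader_kwargs would return.

    import torch
    from torch.utils.data import Dataset, DistributedSampler

    class ToyPacked(Dataset):
        """Hypothetical stand-in for ConcatDataset: packs raw samples into fixed-size chunks."""
        def __init__(self, raw, chunk_size=4):
            self.chunks = [raw[i:i + chunk_size] for i in range(0, len(raw), chunk_size)]
        def __len__(self):
            return len(self.chunks)
        def __getitem__(self, idx):
            return torch.tensor(self.chunks[idx])

    raw = list(range(100))      # 100 raw samples
    packed = ToyPacked(raw)     # 25 packed chunks

    # Old order: a sampler built from the raw dataset is sized for 100 samples,
    # but the DataLoader will iterate over the 25-chunk packed dataset.
    sampler_before = DistributedSampler(raw, num_replicas=1, rank=0)

    # New order (this commit): build the sampler from the packed dataset.
    sampler_after = DistributedSampler(packed, num_replicas=1, rank=0)

    print(len(sampler_before), len(sampler_after))  # 100 vs 25

With the corrected order, the indices the sampler yields always fall inside the packed dataset, which is what moving the get_dataloader_kwargs calls below the ConcatDataset wrapping guarantees for both the train and validation loaders.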