Procházet zdrojové kódy

Fix sampler vs batch_sampler

Matthias Reso před 1 rokem
rodič
revize
5a359b7bf2
1 změnil soubory, kde provedl 1 přidání a 1 odebrání
  1. 1 1
      src/llama_recipes/utils/config_utils.py

+ 1 - 1
src/llama_recipes/utils/config_utils.py

@@ -85,7 +85,7 @@ def get_dataloader_kwargs(train_config, dataset, tokenizer, mode):
             kwargs["collate_fn"] = DataCollatorForSeq2Seq(tokenizer)
         elif train_config.batching_strategy == "packing":
             if train_config.enable_fsdp:
-                kwargs["batch_sampler"] = DistributedSampler(
+                kwargs["sampler"] = DistributedSampler(
                 dataset,
                 rank=dist.get_rank(),
                 num_replicas=dist.get_world_size(),