@@ -85,7 +85,7 @@ def get_dataloader_kwargs(train_config, dataset, tokenizer, mode):
kwargs["collate_fn"] = DataCollatorForSeq2Seq(tokenizer)
elif train_config.batching_strategy == "packing":
if train_config.enable_fsdp:
- kwargs["batch_sampler"] = DistributedSampler(
+ kwargs["sampler"] = DistributedSampler(
dataset,
rank=dist.get_rank(),
num_replicas=dist.get_world_size(),