|
@@ -85,7 +85,7 @@ def get_dataloader_kwargs(train_config, dataset, tokenizer, mode):
|
|
kwargs["collate_fn"] = DataCollatorForSeq2Seq(tokenizer)
|
|
kwargs["collate_fn"] = DataCollatorForSeq2Seq(tokenizer)
|
|
elif train_config.batching_strategy == "packing":
|
|
elif train_config.batching_strategy == "packing":
|
|
if train_config.enable_fsdp:
|
|
if train_config.enable_fsdp:
|
|
- kwargs["batch_sampler"] = DistributedSampler(
|
|
|
|
|
|
+ kwargs["sampler"] = DistributedSampler(
|
|
dataset,
|
|
dataset,
|
|
rank=dist.get_rank(),
|
|
rank=dist.get_rank(),
|
|
num_replicas=dist.get_world_size(),
|
|
num_replicas=dist.get_world_size(),
|