|
@@ -179,12 +179,11 @@ def main(**kwargs):
|
|
|
if not train_config.enable_fsdp or rank == 0:
|
|
|
print(f"--> Validation Set Length = {len(dataset_val)}")
|
|
|
|
|
|
- train_dl_kwargs = get_dataloader_kwargs(train_config, dataset_train, tokenizer, "train")
|
|
|
- val_dl_kwargs = get_dataloader_kwargs(train_config, dataset_val, tokenizer, "val")
|
|
|
-
|
|
|
if train_config.batching_strategy == "packing":
|
|
|
dataset_train = ConcatDataset(dataset_train, chunk_size=train_config.context_length)
|
|
|
|
|
|
+ train_dl_kwargs = get_dataloader_kwargs(train_config, dataset_train, tokenizer, "train")
|
|
|
+
|
|
|
# Create DataLoaders for the training and validation dataset
|
|
|
train_dataloader = torch.utils.data.DataLoader(
|
|
|
dataset_train,
|
|
@@ -197,6 +196,9 @@ def main(**kwargs):
|
|
|
if train_config.run_validation:
|
|
|
if train_config.batching_strategy == "packing":
|
|
|
dataset_val = ConcatDataset(dataset_val, chunk_size=train_config.context_length)
|
|
|
+
|
|
|
+ val_dl_kwargs = get_dataloader_kwargs(train_config, dataset_val, tokenizer, "val")
|
|
|
+
|
|
|
eval_dataloader = torch.utils.data.DataLoader(
|
|
|
dataset_val,
|
|
|
num_workers=train_config.num_workers_dataloader,
|