@@ -6,7 +6,6 @@ from pkg_resources import packaging

 import fire
 import torch
-import torch.distributed as dist
 import torch.optim as optim
 from peft import get_peft_model, prepare_model_for_int8_training
 from torch.distributed.fsdp import (
@@ -14,13 +13,11 @@ from torch.distributed.fsdp import (
 )
 from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload
 from torch.optim.lr_scheduler import StepLR
-from torch.utils.data import DistributedSampler
 from transformers import (
     LlamaForCausalLM,
     LlamaTokenizer,
     LlamaConfig,
-    default_data_collator,
-)
+)
 from transformers.models.llama.modeling_llama import LlamaDecoderLayer

 from llama_recipes.configs import fsdp_config, train_config
@@ -31,6 +28,7 @@ from llama_recipes.utils.config_utils import (
     update_config,
     generate_peft_config,
     generate_dataset_config,
+    get_sampler_kwargs,
 )
 from llama_recipes.utils.dataset_utils import get_preprocessed_dataset

@@ -179,43 +177,24 @@ def main(**kwargs):
     if not train_config.enable_fsdp or rank == 0:
         print(f"--> Validation Set Length = {len(dataset_val)}")

-    train_sampler = None
-    val_sampler = None
-    if train_config.enable_fsdp:
-        train_sampler = DistributedSampler(
-            dataset_train,
-            rank=dist.get_rank(),
-            num_replicas=dist.get_world_size(),
-            shuffle=True,
-        )
-        if train_config.run_validation:
-            val_sampler = DistributedSampler(
-                dataset_val,
-                rank=dist.get_rank(),
-                num_replicas=dist.get_world_size(),
-            )
+    train_dl_kwargs = get_sampler_kwargs(train_config, dataset_train, tokenizer, "train")
+    val_dl_kwargs = get_sampler_kwargs(train_config, dataset_val, tokenizer, "val")

     # Create DataLoaders for the training and validation dataset
     train_dataloader = torch.utils.data.DataLoader(
         dataset_train,
-        batch_size=train_config.batch_size_training,
         num_workers=train_config.num_workers_dataloader,
         pin_memory=True,
-        sampler=train_sampler if train_sampler else None,
-        drop_last=True,
-        collate_fn=default_data_collator,
+        **train_dl_kwargs,
     )

     eval_dataloader = None
     if train_config.run_validation:
         eval_dataloader = torch.utils.data.DataLoader(
             dataset_val,
-            batch_size=train_config.val_batch_size,
             num_workers=train_config.num_workers_dataloader,
             pin_memory=True,
-            sampler=val_sampler if val_sampler else None,
-            drop_last=True,
-            collate_fn=default_data_collator,
+            **val_dl_kwargs,
         )

     # Initialize the optimizer and learning rate scheduler
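
Note: the diff imports and calls get_sampler_kwargs but does not show its body. Based on the call sites above, it is assumed to return the DataLoader keyword arguments that the deleted inline code used to build by hand: the per-split batch size, a DistributedSampler when FSDP is enabled, drop_last, and the default data collator. A minimal sketch under that assumption follows; it is not the actual llama_recipes implementation, and the tokenizer argument (presumably used for collation or packing decisions in the real helper) is ignored here.

    # Hypothetical sketch of get_sampler_kwargs, not the real llama_recipes code.
    import torch.distributed as dist
    from torch.utils.data import DistributedSampler
    from transformers import default_data_collator

    def get_sampler_kwargs(train_config, dataset, tokenizer, mode):
        # Pick the per-split batch size; these config fields appear in the removed code.
        kwargs = {
            "batch_size": train_config.batch_size_training
            if mode == "train"
            else train_config.val_batch_size,
            "drop_last": True,
            "collate_fn": default_data_collator,  # tokenizer accepted only for API parity
        }
        if train_config.enable_fsdp:
            # Shard the dataset across ranks; shuffling only the training split is a
            # common choice, though the removed code left the val sampler at its default.
            kwargs["sampler"] = DistributedSampler(
                dataset,
                rank=dist.get_rank(),
                num_replicas=dist.get_world_size(),
                shuffle=(mode == "train"),
            )
        return kwargs

With a helper shaped like this, **train_dl_kwargs and **val_dl_kwargs expand to the same kinds of arguments (batch_size, sampler, drop_last, collate_fn) that the DataLoaders received before the refactor.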