@@ -11,7 +11,6 @@ import torch.optim as optim
 from peft import get_peft_model, prepare_model_for_int8_training
 from torch.distributed.fsdp import (
     FullyShardedDataParallel as FSDP,
-
 )
 from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload
 from torch.optim.lr_scheduler import StepLR
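
For reference, the two peft helpers imported above (prepare_model_for_int8_training and get_peft_model) are normally used together along the lines of the sketch below. The checkpoint name and LoRA hyperparameters are illustrative placeholders and bitsandbytes/accelerate are assumed to be installed; none of this is prescribed by the diff itself.

    from peft import LoraConfig, get_peft_model, prepare_model_for_int8_training
    from transformers import LlamaForCausalLM

    # Example 8-bit base model; the checkpoint name is only a placeholder.
    model = LlamaForCausalLM.from_pretrained(
        "meta-llama/Llama-2-7b-hf",
        load_in_8bit=True,
        device_map="auto",
    )
    model = prepare_model_for_int8_training(model)  # cast norms to fp32 and ready the model for 8-bit training

    # Illustrative LoRA settings, not values taken from this change.
    peft_config = LoraConfig(
        r=8,
        lora_alpha=32,
        target_modules=["q_proj", "v_proj"],
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
    )
    model = get_peft_model(model, peft_config)  # wrap the base model with trainable LoRA adapters
    model.print_trainable_parameters()
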
@@ -52,7 +51,7 @@ def setup_wandb(train_config, fsdp_config, **kwargs):
     except ImportError:
         raise ImportError(
             "You are trying to use wandb which is not currently installed"
-            " Please install it using pip install wandb"
+            "Please install it using pip install wandb"
         )
     from llama_recipes.configs import wandb_config as WANDB_CONFIG
     wandb_config = WANDB_CONFIG()
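
The hunk above shows only the import guard and the config instantiation inside setup_wandb. A minimal sketch of how the rest of the helper presumably hands those configs to W&B, assuming WANDB_CONFIG and the two config objects are dataclasses whose fields line up with wandb.init() keyword arguments (this tail is not part of the diff):

    import dataclasses

    import wandb

    from llama_recipes.configs import wandb_config as WANDB_CONFIG


    def setup_wandb(train_config, fsdp_config, **kwargs):
        # Hypothetical tail of the helper: feed the dataclass configs to W&B.
        wandb_config = WANDB_CONFIG()
        run = wandb.init(**dataclasses.asdict(wandb_config))  # fields assumed to map onto wandb.init kwargs
        run.config.update(dataclasses.asdict(train_config))   # record the training hyperparameters
        run.config.update(dataclasses.asdict(fsdp_config), allow_val_change=True)
        return run
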
@@ -68,7 +67,6 @@ def main(**kwargs):
     # Update the configuration for the training and sharding process
     train_config, fsdp_config = TRAIN_CONFIG(), FSDP_CONFIG()
     update_config((train_config, fsdp_config), **kwargs)
-
     # Set the seeds for reproducibility
     torch.cuda.manual_seed(train_config.seed)
     torch.manual_seed(train_config.seed)
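
The update_config call above is what routes the fire-supplied **kwargs into the two config dataclasses. A minimal sketch of that kind of helper, not the actual llama_recipes implementation, looks roughly like:

    def update_config(configs, **kwargs):
        """Apply keyword overrides to one or more config dataclasses (sketch only)."""
        if not isinstance(configs, (tuple, list)):
            configs = (configs,)
        for config in configs:
            for key, value in kwargs.items():
                if hasattr(config, key):  # only touch fields the config actually defines
                    setattr(config, key, value)
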
@@ -111,7 +109,6 @@ def main(**kwargs):
                 device_map="auto" if train_config.quantization else None,
                 use_cache=use_cache,
             )
-
         else:
             llama_config = LlamaConfig.from_pretrained(train_config.model_name)
             llama_config.use_cache = use_cache
@@ -158,7 +155,6 @@ def main(**kwargs):
         if train_config.enable_wandb:
             if not train_config.enable_fsdp or rank==0:
                 wandb_run.config.update(peft_config)
-

     #setting up FSDP if enable_fsdp is enabled
     if train_config.enable_fsdp:
@@ -271,7 +267,6 @@ def main(**kwargs):
         if train_config.enable_wandb:
             for k,v in results.items():
                 wandb_run.summary[k] = v
-

 if __name__ == "__main__":
     fire.Fire(main)
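
Because main(**kwargs) is exposed through fire.Fire, any field of train_config or fsdp_config can be overridden from the command line or programmatically. A hypothetical invocation (module path, checkpoint name, and flag values are assumptions, not taken from the diff):

    # CLI equivalent (illustrative):
    #   python -m llama_recipes.finetuning --model_name meta-llama/Llama-2-7b-hf --enable_fsdp --enable_wandb
    from llama_recipes.finetuning import main

    main(
        model_name="meta-llama/Llama-2-7b-hf",  # example checkpoint
        enable_fsdp=True,                       # exercises the FSDP setup path
        enable_wandb=True,                      # exercises setup_wandb and the summary logging above
    )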