@@ -11,6 +11,7 @@ import torch.optim as optim
 from peft import get_peft_model, prepare_model_for_int8_training
 from torch.distributed.fsdp import (
     FullyShardedDataParallel as FSDP,
+
 )
 from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload
 from torch.optim.lr_scheduler import StepLR
@@ -45,12 +46,29 @@ from llama_recipes.utils.train_utils import (
     get_policies
 )
 
-
+def setup_wandb(train_config, fsdp_config, **kwargs):
+    try:
+        import wandb
+    except ImportError:
+        raise ImportError(
+            "You are trying to use wandb which is not currently installed. "
+            "Please install it using pip install wandb"
+        )
+    from llama_recipes.configs import wandb_config as WANDB_CONFIG
+    wandb_config = WANDB_CONFIG()
+    update_config(wandb_config, **kwargs)
+    wandb_entity = None if wandb_config.wandb_entity == 'none' else wandb_config.wandb_entity
+    run = wandb.init(project=wandb_config.wandb_project, entity=wandb_entity)
+    run.config.update(train_config)
+    run.config.update(fsdp_config)
+    return run
+
+
 def main(**kwargs):
     # Update the configuration for the training and sharding process
     train_config, fsdp_config = TRAIN_CONFIG(), FSDP_CONFIG()
     update_config((train_config, fsdp_config), **kwargs)
-
+
     # Set the seeds for reproducibility
     torch.cuda.manual_seed(train_config.seed)
     torch.manual_seed(train_config.seed)
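Note: `update_config(wandb_config, **kwargs)` is applied before `wandb_config.wandb_entity` is read, so CLI overrides take effect; reading the field first would silently ignore them. The `wandb_config` dataclass itself is defined in `llama_recipes/configs` and does not appear in this diff. A minimal sketch of what it plausibly contains, inferred only from the fields used above (`wandb_project`, `wandb_entity`, and the `'none'` sentinel); the defaults shown are hypothetical:

    # Hypothetical reconstruction; the real definition lives in llama_recipes/configs.
    from dataclasses import dataclass

    @dataclass
    class wandb_config:
        wandb_project: str = "llama_recipes"  # assumed default project name
        wandb_entity: str = "none"            # 'none' is mapped to entity=None in setup_wandb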
@@ -68,6 +86,10 @@ def main(**kwargs):
     clear_gpu_cache(local_rank)
     setup_environ_flags(rank)
 
+    wandb_run = None
+    if train_config.enable_wandb and (not train_config.enable_fsdp or rank==0):
+        wandb_run = setup_wandb(train_config, fsdp_config, **kwargs)
+
     # Load the pre-trained model and setup its configuration
     use_cache = False if train_config.enable_fsdp else None
     if train_config.enable_fsdp and train_config.low_cpu_fsdp:
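Initializing `wandb_run = None` before the guard matters under FSDP: only rank 0 creates a run, but every rank later evaluates `wandb_run` when assembling the train() arguments, so the name must be bound on all ranks. The `enable_wandb` flag gating this block must also exist on the training config; that file is not part of this diff. A hypothetical excerpt (the field name comes from the usage above, the default is assumed):

    # Hypothetical excerpt of the train_config dataclass; not shown in this diff.
    from dataclasses import dataclass

    @dataclass
    class train_config:  # the real dataclass carries many more fields
        enable_wandb: bool = False  # assumed default: wandb logging off unless requested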
@@ -89,6 +112,7 @@ def main(**kwargs):
                 device_map="auto" if train_config.quantization else None,
                 use_cache=use_cache,
             )
+
         else:
             llama_config = LlamaConfig.from_pretrained(train_config.model_name)
             llama_config.use_cache = use_cache
@@ -132,6 +156,10 @@ def main(**kwargs):
         peft_config = generate_peft_config(train_config, kwargs)
         model = get_peft_model(model, peft_config)
         model.print_trainable_parameters()
+        if train_config.enable_wandb:
+            if not train_config.enable_fsdp or rank==0:
+                wandb_run.config.update(peft_config)
+
 
     #setting up FSDP if enable_fsdp is enabled
     if train_config.enable_fsdp:
@@ -237,9 +265,14 @@ def main(**kwargs):
         fsdp_config if train_config.enable_fsdp else None,
         local_rank if train_config.enable_fsdp else None,
         rank if train_config.enable_fsdp else None,
+        wandb_run if train_config.enable_wandb else None,
     )
     if not train_config.enable_fsdp or rank==0:
         [print(f'Key: {k}, Value: {v}') for k, v in results.items()]
+        if train_config.enable_wandb:
+            for k,v in results.items():
+                wandb_run.summary[k] = v
+
 
 if __name__ == "__main__":
     fire.Fire(main)
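Usage sketch: since `fire.Fire(main)` maps CLI flags onto main's **kwargs and `update_config` routes matching names into the config dataclasses, wandb can be enabled per run. This assumes the file above is `llama_recipes/finetuning.py`; the PEFT flags are illustrative:

    # Python-side equivalent of the fire CLI; flag names follow the configs in this diff.
    from llama_recipes.finetuning import main

    main(
        use_peft=True,
        peft_method="lora",
        enable_wandb=True,            # train_config flag used throughout this change
        wandb_project="my-project",   # picked up by update_config in setup_wandb
    )

    # Roughly equivalent CLI:
    #   python -m llama_recipes.finetuning --use_peft --peft_method lora \
    #       --enable_wandb --wandb_project my-project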