@@ -50,7 +50,7 @@ def setup_wandb(train_config, fsdp_config, **kwargs):
         import wandb
     except ImportError:
         raise ImportError(
-            "You are trying to use wandb which is not currently installed"
+            "You are trying to use wandb which is not currently installed. "
             "Please install it using pip install wandb"
         )
     from llama_recipes.configs import wandb_config as WANDB_CONFIG
@@ -59,7 +59,7 @@ def setup_wandb(train_config, fsdp_config, **kwargs):
     update_config(wandb_config, **kwargs)
     run = wandb.init(project=wandb_config.wandb_project, entity=wandb_entity)
     run.config.update(train_config)
-    run.config.update(fsdp_config)
+    run.config.update(fsdp_config, allow_val_change=True)
     return run
@@ -84,6 +84,8 @@ def main(**kwargs):
         clear_gpu_cache(local_rank)
         setup_environ_flags(rank)
 
+    wandb_run = None
+
     if train_config.enable_wandb:
         if not train_config.enable_fsdp or rank==0:
             wandb_run = setup_wandb(train_config, fsdp_config, **kwargs)
@@ -152,9 +154,8 @@ def main(**kwargs):
         peft_config = generate_peft_config(train_config, kwargs)
         model = get_peft_model(model, peft_config)
         model.print_trainable_parameters()
-        if train_config.enable_wandb:
-            if not train_config.enable_fsdp or rank==0:
-                wandb_run.config.update(peft_config)
+        if wandb_run:
+            wandb_run.config.update(peft_config)
 
     #setting up FSDP if enable_fsdp is enabled
     if train_config.enable_fsdp:
@@ -260,7 +261,7 @@ def main(**kwargs):
         fsdp_config if train_config.enable_fsdp else None,
         local_rank if train_config.enable_fsdp else None,
         rank if train_config.enable_fsdp else None,
-        wandb_run if train_config.enable_wandb else None,
+        wandb_run,
     )
     if not train_config.enable_fsdp or rank==0:
        [print(f'Key: {k}, Value: {v}') for k, v in results.items()]
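
For context, the net effect of this patch is that wandb_run is always bound in main, so downstream code only needs a truthiness check instead of re-checking enable_wandb and the rank; allow_val_change=True simply suppresses wandb's error when an existing config key would be overwritten with a new value. A minimal, self-contained sketch of that guard pattern (stub names such as TrainConfig and run_demo are illustrative, not part of the patch):

# Illustrative sketch only; mirrors the guard pattern introduced above.
from dataclasses import dataclass

@dataclass
class TrainConfig:
    enable_wandb: bool = False
    enable_fsdp: bool = False

def run_demo(train_config: TrainConfig, rank: int = 0):
    wandb_run = None  # bound up front, even when logging is disabled
    if train_config.enable_wandb:
        if not train_config.enable_fsdp or rank == 0:
            import wandb  # assumed installed when enable_wandb is set
            wandb_run = wandb.init(project="demo")

    # Later stages only check the run object itself.
    if wandb_run:
        wandb_run.config.update({"stage": "peft"}, allow_val_change=True)

run_demo(TrainConfig())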