Bläddra i källkod

Fixing wandb for FSDP

kldarek 1 år sedan
förälder
incheckning
fc5485d916
2 ändrade filer med 11 tillägg och 11 borttagningar
  1. +7 -6
      src/llama_recipes/finetuning.py
  2. +4 -5
      src/llama_recipes/utils/train_utils.py

+ 7 - 6
src/llama_recipes/finetuning.py

@@ -50,7 +50,7 @@ def setup_wandb(train_config, fsdp_config, **kwargs):
         import wandb
     except ImportError:
         raise ImportError(
-            "You are trying to use wandb which is not currently installed"
+            "You are trying to use wandb which is not currently installed. "
             "Please install it using pip install wandb"
         )
     from llama_recipes.configs import wandb_config as WANDB_CONFIG
@@ -59,7 +59,7 @@ def setup_wandb(train_config, fsdp_config, **kwargs):
     update_config(wandb_config, **kwargs)
     run = wandb.init(project=wandb_config.wandb_project, entity=wandb_entity)
     run.config.update(train_config)
-    run.config.update(fsdp_config)
+    run.config.update(fsdp_config, allow_val_change=True)
     return run
 
     
@@ -84,6 +84,8 @@ def main(**kwargs):
         clear_gpu_cache(local_rank)
         setup_environ_flags(rank)
 
+    wandb_run = None
+
     if train_config.enable_wandb:
         if not train_config.enable_fsdp or rank==0:
             wandb_run = setup_wandb(train_config, fsdp_config, **kwargs)    
@@ -152,9 +154,8 @@ def main(**kwargs):
         peft_config = generate_peft_config(train_config, kwargs)
         model = get_peft_model(model, peft_config)
         model.print_trainable_parameters()
-        if train_config.enable_wandb:
-            if not train_config.enable_fsdp or rank==0:
-                wandb_run.config.update(peft_config)
+        if wandb_run:
+            wandb_run.config.update(peft_config)
 
     #setting up FSDP if enable_fsdp is enabled
     if train_config.enable_fsdp:
@@ -260,7 +261,7 @@ def main(**kwargs):
         fsdp_config if train_config.enable_fsdp else None,
         local_rank if train_config.enable_fsdp else None,
         rank if train_config.enable_fsdp else None,
-        wandb_run if train_config.enable_wandb else None,
+        wandb_run,
     )
     if not train_config.enable_fsdp or rank==0:
         [print(f'Key: {k}, Value: {v}') for k, v in results.items()]

+ 4 - 5
src/llama_recipes/utils/train_utils.py

@@ -275,11 +275,10 @@ def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer, wandb
         print(f" {eval_ppl=} {eval_epoch_loss=}")
 
     if wandb_run: 
-        if not train_config.enable_fsdp or rank==0:
-            wandb_run.log({
-                            'eval/perplexity': eval_ppl,
-                            'eval/loss': eval_epoch_loss,
-                        }, commit=False)
+        wandb_run.log({
+                        'eval/perplexity': eval_ppl,
+                        'eval/loss': eval_epoch_loss,
+                    }, commit=False)
 
     return eval_ppl, eval_epoch_loss