
basic wandb logging instrumentation

kldarek 1 year ago
parent
commit
cf373529f7

+ 3 - 0
.gitignore

@@ -1,3 +1,6 @@
 .DS_Store
 __pycache__
 .ipynb_checkpoints
+.gitignore
+wandb/
+artifacts/

+ 1 - 0
src/llama_recipes/configs/__init__.py

@@ -4,3 +4,4 @@
 from llama_recipes.configs.peft import lora_config, llama_adapter_config, prefix_config
 from llama_recipes.configs.fsdp import fsdp_config
 from llama_recipes.configs.training import train_config
+from llama_recipes.configs.wandb import wandb_config

+ 1 - 0
src/llama_recipes/configs/training.py

@@ -36,3 +36,4 @@ class train_config:
     dist_checkpoint_folder: str="fine-tuned" # will be used if using FSDP
     save_optimizer: bool=False # will be used if using FSDP
     use_fast_kernels: bool = False # Enable using SDPA from PyTorch Accelerated Transformers, which makes use of Flash Attention and Xformer memory-efficient kernels
+    enable_wandb: bool = False # Enable wandb for experiment tracking

+ 11 - 0
src/llama_recipes/configs/wandb.py

@@ -0,0 +1,11 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
+
+from dataclasses import dataclass, field
+
+@dataclass
+class wandb_config:
+    wandb_project: str='llama_recipes' # wandb project name
+    wandb_entity: str='none' # wandb entity name
+    wandb_log_model: bool=False # whether or not to log model as artifact at the end of training
+    wandb_watch: str='false' # can be set to 'gradients' or 'all' to log gradients and parameters
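
A minimal usage sketch of the new config (not part of this commit): it assumes the llama_recipes package from this repo and wandb are installed and that you are logged in to wandb; 'my_project' is a placeholder.

import wandb
from llama_recipes.configs import wandb_config

# Instantiate the new dataclass and override the project name (placeholder value).
cfg = wandb_config(wandb_project='my_project')
# 'none' is the sentinel for "use the default entity of the logged-in user".
entity = None if cfg.wandb_entity == 'none' else cfg.wandb_entity
run = wandb.init(project=cfg.wandb_project, entity=entity)
# Record the remaining options alongside the run for reference.
run.config.update({'wandb_log_model': cfg.wandb_log_model, 'wandb_watch': cfg.wandb_watch})
run.finish()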

+ 34 - 2
src/llama_recipes/finetuning.py

@@ -11,6 +11,7 @@ import torch.optim as optim
 from peft import get_peft_model, prepare_model_for_int8_training
 from torch.distributed.fsdp import (
     FullyShardedDataParallel as FSDP,
+
 )
 from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload
 from torch.optim.lr_scheduler import StepLR
@@ -45,12 +46,29 @@ from llama_recipes.utils.train_utils import (
     get_policies
 )
 
-
+def setup_wandb(train_config, fsdp_config, **kwargs):
+    try: 
+        import wandb
+    except ImportError:
+        raise ImportError(
+            "You are trying to use wandb which is not currently installed."
+            " Please install it using pip install wandb."
+        )
+    from llama_recipes.configs import wandb_config as WANDB_CONFIG
+    wandb_config = WANDB_CONFIG()
+    update_config(wandb_config, **kwargs)
+    wandb_entity = None if wandb_config.wandb_entity == 'none' else wandb_config.wandb_entity
+    run = wandb.init(project=wandb_config.wandb_project, entity=wandb_entity)
+    run.config.update(train_config)
+    run.config.update(fsdp_config)
+    return run
+
+    
 def main(**kwargs):
     # Update the configuration for the training and sharding process
     train_config, fsdp_config = TRAIN_CONFIG(), FSDP_CONFIG()
     update_config((train_config, fsdp_config), **kwargs)
-
+        
     # Set the seeds for reproducibility
     torch.cuda.manual_seed(train_config.seed)
     torch.manual_seed(train_config.seed)
@@ -68,6 +86,10 @@ def main(**kwargs):
         clear_gpu_cache(local_rank)
         setup_environ_flags(rank)
 
+    if train_config.enable_wandb:
+        if not train_config.enable_fsdp or rank==0:
+            wandb_run = setup_wandb(train_config, fsdp_config, **kwargs)    
+
     # Load the pre-trained model and setup its configuration
     use_cache = False if train_config.enable_fsdp else None
     if train_config.enable_fsdp and train_config.low_cpu_fsdp:
@@ -89,6 +111,7 @@ def main(**kwargs):
                 device_map="auto" if train_config.quantization else None,
                 use_cache=use_cache,
             )
+            
         else:
             llama_config = LlamaConfig.from_pretrained(train_config.model_name)
             llama_config.use_cache = use_cache
@@ -132,6 +155,10 @@ def main(**kwargs):
         peft_config = generate_peft_config(train_config, kwargs)
         model = get_peft_model(model, peft_config)
         model.print_trainable_parameters()
+        if train_config.enable_wandb:
+            if not train_config.enable_fsdp or rank==0:
+                wandb_run.config.update(peft_config)
+        
 
     #setting up FSDP if enable_fsdp is enabled
     if train_config.enable_fsdp:
@@ -237,9 +264,14 @@ def main(**kwargs):
         fsdp_config if train_config.enable_fsdp else None,
         local_rank if train_config.enable_fsdp else None,
         rank if train_config.enable_fsdp else None,
+        wandb_run if train_config.enable_wandb else None,
     )
     if not train_config.enable_fsdp or rank==0:
         [print(f'Key: {k}, Value: {v}') for k, v in results.items()]
+        if train_config.enable_wandb:
+            for k,v in results.items():
+                wandb_run.summary[k] = v
+        
 
 if __name__ == "__main__":
     fire.Fire(main)
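
A minimal sketch of the rank-0 guard used above (not part of this commit), assuming RANK is set by the launcher (e.g. torchrun) and wandb is installed; the results values are hypothetical, for illustration only.

import os
import wandb

# With FSDP, only rank 0 should own the wandb run to avoid duplicate runs.
enable_fsdp = True
rank = int(os.environ.get("RANK", "0"))  # defaults to 0 for single-process runs

wandb_run = None
if not enable_fsdp or rank == 0:
    wandb_run = wandb.init(project="llama_recipes")

results = {"avg_train_loss": 1.23, "avg_eval_loss": 1.10}  # hypothetical values
if wandb_run is not None:
    for k, v in results.items():
        wandb_run.summary[k] = v  # summary values are per-run, not per-step
    wandb_run.finish()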

+ 18 - 3
src/llama_recipes/utils/train_utils.py

@@ -31,7 +31,7 @@ def set_tokenizer_params(tokenizer: LlamaTokenizer):
 def byte2mb(x):
     return int(x / 2**20)
 
-def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_scheduler, gradient_accumulation_steps, train_config, fsdp_config=None, local_rank=None, rank=None):
+def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_scheduler, gradient_accumulation_steps, train_config, fsdp_config=None, local_rank=None, rank=None, wandb_run=None):
     """
     Trains the model on the given dataloader
 
@@ -99,6 +99,14 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                         optimizer.zero_grad()
                         pbar.update(1)
 
+                if wandb_run: 
+                    if not train_config.enable_fsdp or rank==0:
+                        wandb_run.log({
+                            'train/epoch': epoch + 1,
+                            'train/step': epoch * len(train_dataloader) + step,
+                            'train/loss': loss.detach().float(),
+                        })
+
                 pbar.set_description(f"Training Epoch: {epoch+1}/{train_config.num_epochs}, step {step}/{len(train_dataloader)} completed (loss: {loss.detach().float()})")
             pbar.close()
 
@@ -133,7 +141,7 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
         lr_scheduler.step()
 
         if train_config.run_validation:
-            eval_ppl, eval_epoch_loss = evaluation(model, train_config, eval_dataloader, local_rank, tokenizer)
+            eval_ppl, eval_epoch_loss = evaluation(model, train_config, eval_dataloader, local_rank, tokenizer, wandb_run)
             checkpoint_start_time = time.perf_counter()
             if train_config.save_model and eval_epoch_loss < best_val_loss:
                 if train_config.enable_fsdp:
@@ -213,7 +221,7 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
 
     return results
 
-def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer):
+def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer, wandb_run):
     """
     Evaluates the model on the given dataloader
 
@@ -266,6 +274,13 @@ def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer):
     else:
         print(f" {eval_ppl=} {eval_epoch_loss=}")
 
+    if wandb_run:
+        if not train_config.enable_fsdp or local_rank==0:
+            wandb_run.log({
+                'eval/perplexity': eval_ppl,
+                'eval/loss': eval_epoch_loss,
+            }, commit=False)
+
     return eval_ppl, eval_epoch_loss
 
 def freeze_transformer_layers(model, num_layer):
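
A standalone sketch of the logging pattern above (not part of this commit), assuming wandb is installed and you are logged in: per-step training metrics are committed immediately, while eval metrics use commit=False so they attach to the next committed step instead of advancing the step counter. All metric values are hypothetical.

import wandb

run = wandb.init(project="llama_recipes")

num_epochs, steps_per_epoch = 1, 3
for epoch in range(num_epochs):
    for step in range(steps_per_epoch):
        loss = 1.0 / (step + 1)  # hypothetical loss for illustration
        run.log({
            'train/epoch': epoch + 1,
            'train/step': epoch * steps_per_epoch + step,
            'train/loss': loss,
        })
    # Staged until the next committed run.log() call (or run.finish()).
    run.log({'eval/perplexity': 2.7, 'eval/loss': 1.2}, commit=False)

run.finish()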