
Adding W&B (wandb) experiment tracking (#304)

Hamid Shojanazeri, 1 year ago
Parent
Commit 2e664f45e9

+ 2 - 0
.gitignore

@@ -1,3 +1,5 @@
 .DS_Store
 __pycache__
 .ipynb_checkpoints
+wandb/
+artifacts/

+ 13 - 1
README.md

@@ -203,6 +203,18 @@ sbatch multi_node.slurm
 ```
 You can read more about our fine-tuning strategies [here](./docs/LLM_finetuning.md).
 
+## Weights & Biases Experiment Tracking
+
+You can enable [W&B](https://wandb.ai/) experiment tracking by passing the `use_wandb` flag as shown below. You can change the project name, entity, and other `wandb.init` arguments in `wandb_config`.
+
+```bash
+python -m llama_recipes.finetuning --use_peft --peft_method lora --quantization --model_name /path_of_model_folder/7B --output_dir Path/to/save/PEFT/model --use_wandb
+```
+You'll be able to access a dedicated project or run link on [wandb.ai](https://wandb.ai) and see a dashboard like the one below.
+<div style="display: flex;">
+    <img src="./docs/images/wandb_screenshot.png" alt="wandb screenshot" width="500" />
+</div>
+ 
 # Evaluation Harness
 
 Here, we make use of `lm-evaluation-harness` from `EleutherAI` to evaluate fine-tuned Llama 2 models. This can also be extended to evaluate other optimizations for Llama 2 model inference, such as quantization. Please use this [doc](./eval/README.md) to get started.
@@ -234,7 +246,7 @@ This folder contains a series of benchmark scripts for Llama 2 models inference
 This repository is organized in the following way:
 [benchmarks](./benchmarks): Contains a series of benchmark scripts for Llama 2 models inference on various backends.
 
-[configs](src/llama_recipes/configs/): Contains the configuration files for PEFT methods, FSDP, Datasets.
+[configs](src/llama_recipes/configs/): Contains the configuration files for PEFT methods, FSDP, Datasets, and Weights & Biases experiment tracking.
 
 [docs](docs/): Documentation for single- and multi-GPU fine-tuning recipes.
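
A note on the W&B section added to README.md above: `setup_wandb` (added in `finetuning.py` further down in this diff) forwards the same CLI kwargs to `update_config` on `wandb_config`, so fields such as `project` and `entity` can plausibly be overridden alongside `--use_wandb`. Below is a minimal Python sketch of the equivalent call with hypothetical override values; the exact flag handling depends on `update_config`, so treat it as a sketch rather than documented behavior.

```python
# Minimal sketch, not part of this diff: calling the fine-tuning entry point from
# Python with hypothetical wandb_config overrides (project/entity are fields of
# the wandb_config dataclass added in this commit).
from llama_recipes.finetuning import main

main(
    use_peft=True,
    peft_method="lora",
    quantization=True,
    model_name="/path_of_model_folder/7B",
    output_dir="Path/to/save/PEFT/model",
    use_wandb=True,
    project="my_llama_project",  # assumed override, picked up by update_config(wandb_config, ...)
    entity="my_team",            # assumed override
)
```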
 

BIN
docs/images/wandb_screenshot.png


+ 7 - 0
scripts/spellcheck_conf/wordlist.txt

@@ -1251,3 +1251,10 @@ lm
 prepended
 subtasks
 EleutherAI
+CodeLlama
+LlamaGuard
+OctoAI
+OctoAI's
+PurpleLlama
+Youtube
+wandb

+ 1 - 0
src/llama_recipes/configs/__init__.py

@@ -4,3 +4,4 @@
 from llama_recipes.configs.peft import lora_config, llama_adapter_config, prefix_config
 from llama_recipes.configs.fsdp import fsdp_config
 from llama_recipes.configs.training import train_config
+from llama_recipes.configs.wandb import wandb_config

+ 1 - 0
src/llama_recipes/configs/training.py

@@ -38,4 +38,5 @@ class train_config:
     dist_checkpoint_folder: str="fine-tuned" # will be used if using FSDP
     save_optimizer: bool=False # will be used if using FSDP
     use_fast_kernels: bool = False # Enable using SDPA from PyTorch Accelerated Transformers, making use of Flash Attention and Xformers memory-efficient kernels
+    use_wandb: bool = False # Enable wandb for experiment tracking
     save_metrics: bool = False # saves training metrics to a json file for later plotting

+ 15 - 0
src/llama_recipes/configs/wandb.py

@@ -0,0 +1,15 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
+
+from typing import List, Optional
+from dataclasses import dataclass, field
+
+@dataclass
+class wandb_config:
+    project: str = 'llama_recipes' # wandb project name
+    entity: Optional[str] = None # wandb entity name
+    job_type: Optional[str] = None
+    tags: Optional[List[str]] = None
+    group: Optional[str] = None
+    notes: Optional[str] = None
+    mode: Optional[str] = None
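
For orientation (not part of the diff itself): the dataclass above is converted into `wandb.init` keyword arguments via `dataclasses.asdict` in `setup_wandb` below. A minimal sketch with hypothetical values, assuming `wandb` is installed:

```python
# Minimal sketch of how wandb_config feeds wandb.init (mirrors setup_wandb below).
import dataclasses
import wandb
from llama_recipes.configs import wandb_config

cfg = wandb_config(project="llama_recipes", mode="offline")  # "offline": log locally, no W&B login needed
run = wandb.init(**dataclasses.asdict(cfg))  # fields left as None fall back to wandb defaults
run.finish()
```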

+ 32 - 1
src/llama_recipes/finetuning.py

@@ -4,6 +4,7 @@
 import os
 from pkg_resources import packaging
 
+import dataclasses
 import fire
 import random
 import torch
@@ -49,11 +50,28 @@ from llama_recipes.utils.train_utils import (
 )
 from accelerate.utils import is_xpu_available
 
+def setup_wandb(train_config, fsdp_config, **kwargs):
+    try: 
+        import wandb
+    except ImportError:
+        raise ImportError(
+            "You are trying to use wandb which is not currently installed. "
+            "Please install it using pip install wandb"
+        )
+    from llama_recipes.configs import wandb_config as WANDB_CONFIG
+    wandb_config = WANDB_CONFIG()
+    update_config(wandb_config, **kwargs)
+    init_dict = dataclasses.asdict(wandb_config)
+    run = wandb.init(**init_dict)
+    run.config.update(train_config)
+    run.config.update(fsdp_config, allow_val_change=True)
+    return run
+
+
 def main(**kwargs):
     # Update the configuration for the training and sharding process
     train_config, fsdp_config = TRAIN_CONFIG(), FSDP_CONFIG()
     update_config((train_config, fsdp_config), **kwargs)
-
     # Set the seeds for reproducibility
     if is_xpu_available():
         torch.xpu.manual_seed(train_config.seed)
@@ -75,6 +93,12 @@ def main(**kwargs):
         clear_gpu_cache(local_rank)
         setup_environ_flags(rank)
 
+    wandb_run = None
+
+    if train_config.use_wandb:
+        if not train_config.enable_fsdp or rank==0:
+            wandb_run = setup_wandb(train_config, fsdp_config, **kwargs)    
+
     # Load the pre-trained model and setup its configuration
     use_cache = False if train_config.enable_fsdp else None
     if train_config.enable_fsdp and train_config.low_cpu_fsdp:
@@ -130,6 +154,9 @@ def main(**kwargs):
         peft_config = generate_peft_config(train_config, kwargs)
         model = get_peft_model(model, peft_config)
         model.print_trainable_parameters()
+        if wandb_run:
+            wandb_run.config.update(peft_config)
+
         
     hsdp_device_mesh = None
     if fsdp_config.hsdp and fsdp_config.sharding_strategy == ShardingStrategy.HYBRID_SHARD:
@@ -250,9 +277,13 @@ def main(**kwargs):
         fsdp_config if train_config.enable_fsdp else None,
         local_rank if train_config.enable_fsdp else None,
         rank if train_config.enable_fsdp else None,
+        wandb_run,
     )
     if not train_config.enable_fsdp or rank==0:
         [print(f'Key: {k}, Value: {v}') for k, v in results.items()]
+        if train_config.use_wandb:
+            for k,v in results.items():
+                wandb_run.summary[k] = v
 
 if __name__ == "__main__":
     fire.Fire(main)

+ 18 - 4
src/llama_recipes/utils/train_utils.py

@@ -33,7 +33,7 @@ def set_tokenizer_params(tokenizer: LlamaTokenizer):
 def byte2mb(x):
     return int(x / 2**20)
 
-def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_scheduler, gradient_accumulation_steps, train_config, fsdp_config=None, local_rank=None, rank=None):
+def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_scheduler, gradient_accumulation_steps, train_config, fsdp_config=None, local_rank=None, rank=None, wandb_run=None):
     """
     Trains the model on the given dataloader
 
@@ -133,6 +133,14 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                         optimizer.zero_grad()
                         pbar.update(1)
 
+                if wandb_run: 
+                    if not train_config.enable_fsdp or rank==0:
+                        wandb_run.log({
+                            'train/epoch': epoch + 1,
+                            'train/step': epoch * len(train_dataloader) + step,
+                            'train/loss': loss.detach().float(),
+                        })
+
                 pbar.set_description(f"Training Epoch: {epoch+1}/{train_config.num_epochs}, step {step}/{len(train_dataloader)} completed (loss: {loss.detach().float()})")
 
                 if train_config.save_metrics:
@@ -161,7 +169,7 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
         lr_scheduler.step()
 
         if train_config.run_validation:
-            eval_ppl, eval_epoch_loss, temp_val_loss, temp_step_perplexity = evaluation(model, train_config, eval_dataloader, local_rank, tokenizer)
+            eval_ppl, eval_epoch_loss, temp_val_loss, temp_step_perplexity = evaluation(model, train_config, eval_dataloader, local_rank, tokenizer, wandb_run)
             if train_config.save_metrics:
                 val_step_loss.extend(temp_val_loss)
                 val_step_perplexity.extend(temp_step_perplexity)
@@ -252,7 +260,7 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
 
     return results
 
-def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer):
+def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer, wandb_run):
     """
     Evaluates the model on the given dataloader
 
@@ -315,6 +323,12 @@ def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer):
             print(f" {eval_ppl=} {eval_epoch_loss=}")
     else:
         print(f" {eval_ppl=} {eval_epoch_loss=}")
+
+    if wandb_run: 
+        wandb_run.log({
+                        'eval/perplexity': eval_ppl,
+                        'eval/loss': eval_epoch_loss,
+                    }, commit=False)
         
     return eval_ppl, eval_epoch_loss, val_step_loss, val_step_perplexity
 
@@ -478,4 +492,4 @@ def save_to_json(output_filename, train_step_loss, train_epoch_loss, train_step_
         "val_epoch_perplexity": val_epoch_ppl
     }
     with open(output_filename, "w") as f:
-        json.dump(metrics_data, f)
+        json.dump(metrics_data, f)
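
One remark on the evaluation logging above (general `wandb` behavior, not something stated in this diff): `wandb_run.log(..., commit=False)` stages the metrics without advancing the step counter, so the eval metrics are recorded together with the next committed `log` call. A standalone sketch:

```python
# Minimal sketch of commit=False semantics, independent of this repository.
import wandb

run = wandb.init(project="scratch", mode="offline")  # offline: nothing is uploaded
run.log({"eval/loss": 1.23}, commit=False)  # staged; step counter not advanced
run.log({"train/loss": 0.98})               # commits both metrics at the same step
run.finish()
```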