
wandb logging feedback

kldarek · 1 year ago · commit 989b6ee812

README.md (+7, -3)

@@ -182,9 +182,13 @@ You can read more about our fine-tuning strategies [here](./docs/LLM_finetuning.
-You can enable [W&B](https://wandb.ai/) experiment tracking by using `enable_wandb` flag as below. You can change the project name and entity in `wandb_config`. 
+You can enable [W&B](https://wandb.ai/) experiment tracking by using the `use_wandb` flag as below. You can change the project name and entity in `wandb_config`.
 
 ```bash
-python -m llama_recipes.finetuning --use_peft --peft_method lora --quantization --model_name /path_of_model_folder/7B --output_dir Path/to/save/PEFT/model --enable_wandb
+python -m llama_recipes.finetuning --use_peft --peft_method lora --quantization --model_name /path_of_model_folder/7B --output_dir Path/to/save/PEFT/model --use_wandb
 ```
-
+You'll find a link to your dedicated project or run on [wandb.ai](https://wandb.ai), with a dashboard like the one below.
+<div style="display: flex;">
+    <img src="./docs/images/wandb_screenshot.png" alt="wandb screenshot" width="500" />
+</div>
+ 
 
 # Demo Apps
 This folder contains a series of Llama2-powered apps:
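A note on the W&B example above: `update_config` is applied to `wandb_config` with the same CLI kwargs (see the `finetuning.py` hunk further down), so the project and entity can plausibly be overridden straight from the command line as well. A hypothetical invocation, assuming the flag names map one-to-one onto the `wandb_config` fields (not confirmed by this commit):

```bash
# Hypothetical flags derived from the wandb_config dataclass fields.
python -m llama_recipes.finetuning --use_peft --peft_method lora --quantization \
    --model_name /path_of_model_folder/7B --output_dir Path/to/save/PEFT/model \
    --use_wandb --project my_project --entity my_team
```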
@@ -203,7 +207,7 @@ This folder contains a series of Llama2-powered apps:
 # Repository Organization
 This repository is organized in the following way:
 
-[configs](src/llama_recipes/configs/): Contains the configuration files for PEFT methods, FSDP, Datasets, W&B experiment tracking.
+[configs](src/llama_recipes/configs/): Contains the configuration files for PEFT methods, FSDP, datasets, and Weights & Biases experiment tracking.
 
 [docs](docs/): Example recipes for single- and multi-GPU fine-tuning.
 

docs/images/wandb_screenshot.png (BIN)


src/llama_recipes/configs/training.py (+1, -1)

@@ -36,4 +36,4 @@ class train_config:
     dist_checkpoint_folder: str="fine-tuned" # will be used if using FSDP
     save_optimizer: bool=False # will be used if using FSDP
     use_fast_kernels: bool = False # Enable using SDPA from PyTorch Accelerated Transformers, making use of Flash Attention and Xformers memory-efficient kernels
-    enable_wandb: bool = False # add wandb for experient tracking
+    use_wandb: bool = False # Enable wandb for experiment tracking
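
Since `train_config` is a plain dataclass, the renamed flag can also be flipped programmatically, mirroring what `--use_wandb` does on the CLI. A minimal sketch, assuming `update_config` lives in `llama_recipes.utils.config_utils` (that import path is not shown in this commit):

```python
# Minimal sketch: setting the renamed flag the same way the CLI kwargs do.
from llama_recipes.configs.training import train_config
from llama_recipes.utils.config_utils import update_config  # assumed path

cfg = train_config()
update_config(cfg, use_wandb=True)  # equivalent to passing --use_wandb
assert cfg.use_wandb
```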

src/llama_recipes/configs/wandb.py (+8, -2)

@@ -1,9 +1,15 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
 
+from typing import List, Optional
 from dataclasses import dataclass, field
 
 @dataclass
 class wandb_config:
-    wandb_project: str='llama_recipes' # wandb project name
-    wandb_entity: str='none' # wandb entity name
+    project: str = 'llama_recipes' # wandb project name
+    entity: Optional[str] = None # wandb entity name
+    job_type: Optional[str] = None # wandb job type, e.g. 'train' or 'eval'
+    tags: Optional[List[str]] = None # tags to attach to the run
+    group: Optional[str] = None # group name for organizing related runs
+    notes: Optional[str] = None # free-form notes describing the run
+    mode: Optional[str] = None # 'online', 'offline', or 'disabled'
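
Worth noting: each new field matches a `wandb.init` keyword argument (`project`, `entity`, `job_type`, `tags`, `group`, `notes`, and `mode` are all accepted by `wandb.init`), which is what makes the `dataclasses.asdict` splat in `finetuning.py` below work. A minimal sketch with hypothetical values:

```python
# Minimal sketch: the config fields line up with wandb.init kwargs,
# so the whole dataclass can be splatted into the call. Values are illustrative.
import dataclasses
import wandb
from llama_recipes.configs import wandb_config

cfg = wandb_config(entity="my-team", tags=["lora", "7B"], mode="offline")
run = wandb.init(**dataclasses.asdict(cfg))  # None fields fall back to wandb defaults
run.finish()
```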

src/llama_recipes/finetuning.py (+5, -4)

@@ -4,6 +4,7 @@
 import os
 from pkg_resources import packaging
 
+import dataclasses
 import fire
 import random
 import torch
@@ -55,9 +56,9 @@ def setup_wandb(train_config, fsdp_config, **kwargs):
         )
     from llama_recipes.configs import wandb_config as WANDB_CONFIG
     wandb_config = WANDB_CONFIG()
-    wandb_entity = None if wandb_config.wandb_entity == 'none' else wandb_config.wandb_entity
     update_config(wandb_config, **kwargs)
-    run = wandb.init(project=wandb_config.wandb_project, entity=wandb_entity)
+    init_dict = dataclasses.asdict(wandb_config)
+    run = wandb.init(**init_dict)
     run.config.update(train_config)
     run.config.update(fsdp_config, allow_val_change=True)
     return run
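
A design note on the simplification above: with `entity` now defaulting to `None`, the old `'none'` string sentinel and its special-casing become unnecessary, since `wandb.init` treats `entity=None` as "use the default entity for the logged-in account".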
@@ -86,7 +87,7 @@ def main(**kwargs):
 
     wandb_run = None
 
-    if train_config.enable_wandb:
+    if train_config.use_wandb:
         if not train_config.enable_fsdp or rank==0:
             wandb_run = setup_wandb(train_config, fsdp_config, **kwargs)    
 
@@ -265,7 +266,7 @@ def main(**kwargs):
     )
     if not train_config.enable_fsdp or rank==0:
         [print(f'Key: {k}, Value: {v}') for k, v in results.items()]
-        if train_config.enable_wandb:
+        if train_config.use_wandb:
             for k,v in results.items():
                 wandb_run.summary[k] = v
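
One last note on the logging pattern: `run.summary` stores final, single-value metrics (one value per key for the whole run), while `run.log` records step-wise time series, so writing the `results` dict into the summary surfaces the end-of-training numbers in the W&B runs table. A sketch with hypothetical keys (the real keys come from the `results` dict returned by the train loop):

```python
# Hypothetical keys; illustrates summary (final values) vs. log (time series).
wandb_run.summary["avg_train_loss"] = 0.82  # one final value per run
wandb_run.log({"train_loss": 0.82})         # by contrast, log() appends a time-series point
```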