
cleanup spaces

kldarek 1 year ago
parent commit f2406cac07

+ 0 - 2
src/llama_recipes/configs/wandb.py

@@ -7,5 +7,3 @@ from dataclasses import dataclass, field
 class wandb_config:
     wandb_project: str='llama_recipes' # wandb project name
     wandb_entity: str='none' # wandb entity name
-    wandb_log_model: bool=False # whether or not to log model as artifact at the end of training
-    wandb_watch: str='false' # can be set to 'gradients' or 'all' to log gradients and parameters
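
After this change the dataclass keeps only the project and entity fields. A minimal sketch of turning it into wandb.init arguments, assuming the 'none' default means "unset" (the mapping helper below is illustrative, not the repo's code):

import dataclasses
import wandb

@dataclasses.dataclass
class wandb_config:
    wandb_project: str = 'llama_recipes'  # wandb project name
    wandb_entity: str = 'none'            # wandb entity name

config = wandb_config()
# Map the prefixed fields onto wandb.init's 'project' and 'entity'
# keyword arguments, treating the 'none' sentinel as unset.
init_kwargs = {
    k.removeprefix('wandb_'): v
    for k, v in dataclasses.asdict(config).items()
    if v != 'none'
}
run = wandb.init(**init_kwargs)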

+ 1 - 6
src/llama_recipes/finetuning.py

@@ -11,7 +11,6 @@ import torch.optim as optim
 from peft import get_peft_model, prepare_model_for_int8_training
 from torch.distributed.fsdp import (
     FullyShardedDataParallel as FSDP,
-
 )
 from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload
 from torch.optim.lr_scheduler import StepLR
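
These imports back the FSDP path in main(). A minimal sketch of how FSDP, CPUOffload, and StepLR fit together (illustrative; not the repo's exact wrapping policy):

import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload
from torch.optim.lr_scheduler import StepLR

# Assumes launch via torchrun so RANK/WORLD_SIZE env vars are set.
dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

model = torch.nn.Linear(512, 512).cuda()
# offload_params parks parameters on CPU between uses, trading speed for memory.
model = FSDP(model, cpu_offload=CPUOffload(offload_params=True))

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
scheduler = StepLR(optimizer, step_size=1, gamma=0.85)  # decay LR once per epoch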
@@ -52,7 +51,7 @@ def setup_wandb(train_config, fsdp_config, **kwargs):
     except ImportError:
         raise ImportError(
             "You are trying to use wandb which is not currently installed"
-            " Please install it using pip install wandb"
+            "Please install it using pip install wandb"
         )
     from llama_recipes.configs import wandb_config as WANDB_CONFIG
     wandb_config = WANDB_CONFIG()
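
The hunk cuts off right after WANDB_CONFIG() is instantiated. A hypothetical sketch of how the function might continue; everything past that line is an assumption, not the repo's actual code:

def setup_wandb(train_config, fsdp_config, **kwargs):
    # Guarded-import pattern from the diff; the leading space on the second
    # literal keeps the concatenated message readable.
    try:
        import wandb
    except ImportError:
        raise ImportError(
            "You are trying to use wandb which is not currently installed."
            " Please install it using pip install wandb"
        )
    from llama_recipes.configs import wandb_config as WANDB_CONFIG
    wandb_config = WANDB_CONFIG()
    # Assumed continuation: start the run and record both configs on it.
    run = wandb.init(project=wandb_config.wandb_project)
    run.config.update(vars(train_config))
    run.config.update(vars(fsdp_config))
    return run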
@@ -68,7 +67,6 @@ def main(**kwargs):
     # Update the configuration for the training and sharding process
     train_config, fsdp_config = TRAIN_CONFIG(), FSDP_CONFIG()
     update_config((train_config, fsdp_config), **kwargs)
-        
     # Set the seeds for reproducibility
     torch.cuda.manual_seed(train_config.seed)
     torch.manual_seed(train_config.seed)
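
update_config folds the CLI keyword overrides collected by fire into the config dataclasses. A hypothetical sketch of what such a helper could do, assuming it simply sets matching attributes (the real implementation lives elsewhere in the repo):

def update_config(configs, **kwargs):
    """Set each kwarg on whichever config object defines that attribute."""
    if not isinstance(configs, tuple):
        configs = (configs,)
    for key, value in kwargs.items():
        for config in configs:
            if hasattr(config, key):
                setattr(config, key, value)

# e.g. `python finetuning.py --seed 7` reaches main() as kwargs={'seed': 7}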
@@ -111,7 +109,6 @@ def main(**kwargs):
                 device_map="auto" if train_config.quantization else None,
                 use_cache=use_cache,
             )
-            
         else:
             llama_config = LlamaConfig.from_pretrained(train_config.model_name)
             llama_config.use_cache = use_cache
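
The hunk shows only the tail of the from_pretrained call. A hedged reconstruction of the two load paths, assuming int8 quantization via bitsandbytes (consistent with the prepare_model_for_int8_training import above); treat the exact argument list and the stand-in config as assumptions:

from transformers import LlamaConfig, LlamaForCausalLM

class _TC:  # illustrative stand-in for train_config
    model_name = "meta-llama/Llama-2-7b-hf"
    quantization = True
train_config = _TC()

use_cache = False  # assumed: KV caching is typically disabled during training

if train_config.quantization:
    # 8-bit load; device_map="auto" spreads weights over available devices.
    model = LlamaForCausalLM.from_pretrained(
        train_config.model_name,
        load_in_8bit=True,
        device_map="auto",
        use_cache=use_cache,
    )
else:
    # Build from config (randomly initialized here; the repo may handle
    # weight loading separately).
    llama_config = LlamaConfig.from_pretrained(train_config.model_name)
    llama_config.use_cache = use_cache
    model = LlamaForCausalLM(llama_config)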
@@ -158,7 +155,6 @@ def main(**kwargs):
         if train_config.enable_wandb:
             if not train_config.enable_fsdp or rank==0:
                 wandb_run.config.update(peft_config)
-        
 
     #setting up FSDP if enable_fsdp is enabled
     if train_config.enable_fsdp:
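
On the main process the PEFT settings are mirrored into the wandb run config. A minimal illustrative example with a peft LoraConfig (the hyperparameter values are made up):

import wandb
from peft import LoraConfig

peft_config = LoraConfig(
    r=8,
    lora_alpha=32,
    target_modules=["q_proj", "v_proj"],
    lora_dropout=0.05,
    task_type="CAUSAL_LM",
)

run = wandb.init(project="llama_recipes")
run.config.update(vars(peft_config))  # LoRA hyperparameters land in the run config
run.finish()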
@@ -271,7 +267,6 @@ def main(**kwargs):
         if train_config.enable_wandb:
             for k,v in results.items():
                 wandb_run.summary[k] = v
-        
 
 if __name__ == "__main__":
     fire.Fire(main)
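
Scalars written to run.summary become the run's headline metrics in the W&B UI, which is what the loop over results.items() above achieves. A small illustrative example; the metric names and values are hypothetical:

import wandb

run = wandb.init(project="llama_recipes")
results = {"avg_train_loss": 0.42, "avg_eval_loss": 0.57}  # hypothetical values
for k, v in results.items():
    run.summary[k] = v
run.finish()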