Explorar o código

Adding config to conditionally save stats

Beto hai 1 ano
pai
achega
17d02c3b44
Modificáronse 2 ficheiros con 24 adicións e 11 borrados
  1. 1 0
      src/llama_recipes/configs/training.py
  2. 23 11
      src/llama_recipes/utils/train_utils.py

+ 1 - 0
src/llama_recipes/configs/training.py

@@ -36,3 +36,4 @@ class train_config:
     dist_checkpoint_folder: str="fine-tuned" # will be used if using FSDP
     save_optimizer: bool=False # will be used if using FSDP
     use_fast_kernels: bool = False # Enable using SDPA from PyTorch Accelerated Transformers, making use of Flash Attention and Xformer memory-efficient kernels
+    save_metrics: bool = False # saves training metrics to a json file for later plotting

+ 23 - 11
src/llama_recipes/utils/train_utils.py

@@ -59,19 +59,23 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
     if train_config.enable_fsdp:
         world_size = int(os.environ["WORLD_SIZE"]) 
 
-    metrics_filename = f"{train_config.output_dir}/metrics_data_{local_rank}-{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}.json"
+    
+
     autocast = torch.cuda.amp.autocast if train_config.use_fp16 else nullcontext
 
     train_prep = []
     train_loss = []
-    train_step_perplexity = []
-    train_step_loss = []
-
     val_prep = []
     val_loss =[]
-    val_step_loss = []
-    val_step_perplexity = []
 
+    if train_config.save_metrics:
+        print(f"Save metrics TRUE {train_config.save_metrics}")
+        metrics_filename = f"{train_config.output_dir}/metrics_data_{local_rank}-{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}.json"
+        train_step_perplexity = []
+        train_step_loss = []
+        val_step_loss = []
+        val_step_perplexity = []
+        
     epoch_times = []
     checkpoint_times = []
     results = {}
@@ -112,6 +116,9 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                         pbar.update(1)
 
                 pbar.set_description(f"Training Epoch: {epoch+1}/{train_config.num_epochs}, step {step}/{len(train_dataloader)} completed (loss: {loss.detach().float()})")
+
+                if train_config.save_metrics:
+                    save_to_json(metrics_filename, train_step_loss, train_loss, train_step_perplexity, train_prep, val_step_loss, val_loss, val_step_perplexity, val_prep)
             pbar.close()
 
         epoch_end_time = time.perf_counter()-epoch_start_time
@@ -146,8 +153,10 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
 
         if train_config.run_validation:
             eval_ppl, eval_epoch_loss, temp_val_loss, temp_step_perplexity = evaluation(model, train_config, eval_dataloader, local_rank, tokenizer)
-            val_step_loss.extend(temp_val_loss)
-            val_step_perplexity.extend(temp_step_perplexity)
+            if train_config.save_metrics:
+                val_step_loss.extend(temp_val_loss)
+                val_step_perplexity.extend(temp_step_perplexity)
+
             checkpoint_start_time = time.perf_counter()
             if train_config.save_model and eval_epoch_loss < best_val_loss:
                 if train_config.enable_fsdp:
@@ -207,7 +216,8 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
             print(f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epoch time {epoch_end_time}s")
         
         # Saving the results every epoch to plot later
-        save_to_json(metrics_filename, train_step_loss, train_loss, train_step_perplexity, train_prep, val_step_loss, val_loss, val_step_perplexity, val_prep)
+        if train_config.save_metrics:
+            save_to_json(metrics_filename, train_step_loss, train_loss, train_step_perplexity, train_prep, val_step_loss, val_loss, val_step_perplexity, val_prep)
 
     avg_epoch_time = sum(epoch_times)/ len(epoch_times)
     avg_checkpoint_time = sum(checkpoint_times)/ len(checkpoint_times) if len(checkpoint_times) > 0 else 0
@@ -262,8 +272,10 @@ def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer):
                 # Forward pass and compute loss
                 outputs = model(**batch)
                 loss = outputs.loss
-                val_step_loss.append(loss.detach().float().item())
-                val_step_perplexity.append(float(torch.exp(loss.detach().float())))  
+                if train_config.save_metrics:
+                    val_step_loss.append(loss.detach().float().item())
+                    val_step_perplexity.append(float(torch.exp(loss.detach().float())))  
+
                 eval_loss += loss.detach().float()
             # Decode predictions and add to evaluation predictions list
             preds = torch.argmax(outputs.logits, -1)