
Remove print as it breaks progress bar and update progress bar description instead

Matthias Reso, 1 year ago
Commit 47ae6d0326
1 changed file with 8 additions and 6 deletions

+ 8 - 6
utils/train_utils.py

@@ -83,7 +83,9 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
         with MemoryTrace() as memtrace:  # track the memory usage
             model.train()
             total_loss = 0.0
-            for step, batch in enumerate(tqdm(train_dataloader,colour="blue", desc=f"Training Epoch{epoch}")):
+            total_length = len(train_dataloader)//gradient_accumulation_steps
+            pbar = tqdm(colour="blue", desc=f"Training Epoch: {epoch}", total=total_length)
+            for step, batch in enumerate(train_dataloader):
                 for key in batch.keys():
                     if train_config.enable_fsdp:
                         batch[key] = batch[key].to(local_rank)
@@ -99,17 +101,17 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                         scaler.step(optimizer)
                         scaler.update()
                         optimizer.zero_grad()
+                        pbar.update(step//gradient_accumulation_steps)
                 else:
                     # regular backpropagation when fp16 is not used
                     loss.backward()
                     if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
                         optimizer.step()
                         optimizer.zero_grad()
-                if train_config.enable_fsdp:
-                    if rank==0:       
-                        print(f"\n step {step} is completed and loss is {loss.detach().float()}")
-                else:
-                    print(f"\n step {step} is completed and loss is {loss.detach().float()}")
+                        pbar.update(step//gradient_accumulation_steps)
+                
+                pbar.set_description(f"Training Epoch: {epoch}/{train_config.num_epochs}, step {step}/{len(train_dataloader)} completed (loss: {loss.detach().float()})")
+                
         epoch_end_time = time.perf_counter()-epoch_start_time
         epoch_times.append(epoch_end_time)    
         # Reducing total_loss across all devices if there's more than one CUDA device
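The pattern introduced by this commit, shown outside the diff as a minimal, self-contained sketch: a manually managed tqdm bar in a gradient-accumulation loop, with the per-step loss reported via set_description instead of print. This is an illustration, not code from the repo; it assumes a Hugging-Face-style model whose forward call returns an object with a .loss tensor, and run_epoch, dataloader, num_epochs, and gradient_accumulation_steps are placeholder names.

# Minimal sketch (not from the repo): drive a tqdm bar manually instead of
# printing per-step loss, which would break the bar's in-place redraw.
from tqdm import tqdm

def run_epoch(model, dataloader, optimizer, epoch, num_epochs,
              gradient_accumulation_steps=4):
    model.train()
    # One bar tick per optimizer step, not per batch.
    total_length = len(dataloader) // gradient_accumulation_steps
    pbar = tqdm(colour="blue", desc=f"Training Epoch: {epoch}", total=total_length)
    for step, batch in enumerate(dataloader):
        # Assumes a HF-style output object exposing .loss (placeholder assumption).
        loss = model(**batch).loss / gradient_accumulation_steps
        loss.backward()
        if (step + 1) % gradient_accumulation_steps == 0 or step == len(dataloader) - 1:
            optimizer.step()
            optimizer.zero_grad()
            pbar.update(1)  # advance by one completed optimizer step
        # Report the loss in the bar's description; tqdm redraws it in place.
        pbar.set_description(
            f"Training Epoch: {epoch}/{num_epochs}, "
            f"step {step}/{len(dataloader)} completed (loss: {loss.detach().float()})"
        )
    pbar.close()

Design note: the bar is created with total=len(dataloader)//gradient_accumulation_steps and advanced by one per completed optimizer step, so the displayed count stays aligned with the total, while the per-step loss lives in the description rather than in print output that would interleave with the bar.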