소스 검색

bugfix: update tqdm bar with the fixed gradient_accumulation_steps

hongbo.mo 1 년 전
부모
커밋
0bc6a07a80
1개의 변경된 파일, 4개의 추가작업 그리고 3개의 삭제
  1. 4 3
      src/llama_recipes/utils/train_utils.py

+ 4 - 3
src/llama_recipes/utils/train_utils.py

@@ -86,16 +86,17 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                         scaler.step(optimizer)
                         scaler.update()
                         optimizer.zero_grad()
-                        pbar.update(step//gradient_accumulation_steps)
+                        pbar.update(gradient_accumulation_steps)
                 else:
                     # regular backpropagation when fp16 is not used
                     loss.backward()
                     if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
                         optimizer.step()
                         optimizer.zero_grad()
-                        pbar.update(step//gradient_accumulation_steps)
+                        pbar.update(gradient_accumulation_steps)
                 
                 pbar.set_description(f"Training Epoch: {epoch}/{train_config.num_epochs}, step {step}/{len(train_dataloader)} completed (loss: {loss.detach().float()})")
+            pbar.close()
                 
         epoch_end_time = time.perf_counter()-epoch_start_time
         epoch_times.append(epoch_end_time)    
@@ -400,4 +401,4 @@ def save_train_params(train_config, fsdp_config, rank):
         with open(file_name, 'w') as f:
             f.write(config_yaml)
         if rank==0:
-            print(f"training params are saved in {file_name}")
+            print(f"training params are saved in {file_name}")