소스 검색

Fix pbar update

Matthias Reso 1 년 전
부모
커밋
c33ea3cacb
1개의 변경된 파일3개의 추가작업 그리고 3개의 파일을 삭제
  1. 3 3
      src/llama_recipes/utils/train_utils.py

+ 3 - 3
src/llama_recipes/utils/train_utils.py

@@ -86,15 +86,15 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                         scaler.step(optimizer)
                         scaler.update()
                         optimizer.zero_grad()
-                        pbar.update(gradient_accumulation_steps)
+                        pbar.update(1)
                 else:
                     # regular backpropagation when fp16 is not used
                     loss.backward()
                     if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
                         optimizer.step()
                         optimizer.zero_grad()
-                        pbar.update(gradient_accumulation_steps)
-                
+                        pbar.update(1)
+
                 pbar.set_description(f"Training Epoch: {epoch+1}/{train_config.num_epochs}, step {step}/{len(train_dataloader)} completed (loss: {loss.detach().float()})")
             pbar.close()