
Fix pbar update

Matthias Reso 1 year ago
parent
commit
c33ea3cacb
1 file changed with 3 additions and 3 deletions

src/llama_recipes/utils/train_utils.py (+3, -3)

@@ -86,15 +86,15 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                         scaler.step(optimizer)
                         scaler.update()
                         optimizer.zero_grad()
-                        pbar.update(gradient_accumulation_steps)
+                        pbar.update(1)
                 else:
                     # regular backpropagation when fp16 is not used
                     loss.backward()
                     if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
                         optimizer.step()
                         optimizer.zero_grad()
-                        pbar.update(gradient_accumulation_steps)
-                
+                        pbar.update(1)
+
                 pbar.set_description(f"Training Epoch: {epoch+1}/{train_config.num_epochs}, step {step}/{len(train_dataloader)} completed (loss: {loss.detach().float()})")
             pbar.close()
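The change makes sense if the progress bar's total is the number of optimizer steps rather than the number of dataloader batches (the surrounding code in train_utils.py appears to set the tqdm total to len(train_dataloader) // gradient_accumulation_steps). In that case each optimizer step should advance the bar by exactly one unit, and `pbar.update(gradient_accumulation_steps)` would overshoot the total. Below is a minimal, self-contained sketch of the pattern, assuming that total; the model, data, and names like `dummy_batches` are illustrative only, not taken from the repo.

```python
from tqdm import tqdm
import torch

model = torch.nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
dummy_batches = [(torch.randn(8, 4), torch.randn(8, 1)) for _ in range(12)]
gradient_accumulation_steps = 4

# Assumption: the bar counts optimizer steps, mirroring
# len(train_dataloader) // gradient_accumulation_steps in train_utils.py.
total_steps = len(dummy_batches) // gradient_accumulation_steps
pbar = tqdm(total=total_steps, desc="Training")

for step, (x, y) in enumerate(dummy_batches):
    loss = torch.nn.functional.mse_loss(model(x), y)
    # Scale the loss so accumulated gradients match one large-batch step.
    (loss / gradient_accumulation_steps).backward()
    if (step + 1) % gradient_accumulation_steps == 0 or step == len(dummy_batches) - 1:
        optimizer.step()
        optimizer.zero_grad()
        # One optimizer step == one unit of progress, hence update(1);
        # update(gradient_accumulation_steps) would run past the total.
        pbar.update(1)
pbar.close()
```

With 12 batches and an accumulation window of 4, the bar's total is 3 and update(1) fires exactly three times, so the bar finishes at 100% instead of reporting 12/3.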