소스 검색

enable grad on loss tensor

abhilash1910 1 년 전
부모
커밋
81fecf3d4b
1개의 변경된 파일1개의 추가작업 그리고 0개의 파일을 삭제
  1. 1 0
      src/llama_recipes/utils/train_utils.py

+ 1 - 0
src/llama_recipes/utils/train_utils.py

@@ -79,6 +79,7 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                 loss = model(**batch).loss
                 loss = loss / gradient_accumulation_steps
                 total_loss += loss.detach().float()
+                loss = torch.autograd.Variable(loss, required_grad = True)
                 if train_config.use_fp16:
                     # if fp16 is enabled, use gradient scaler to handle gradient update
                     scaler.scale(loss).backward()