
pass weight_decay into optimizer

Shijie Wu 1 year ago
parent
commit
91e2573aa8
1 changed file with 2 additions and 1 deletion
  1. src/llama_recipes/finetuning.py (+2, −1)

src/llama_recipes/finetuning.py (+2, −1)

@@ -226,12 +226,13 @@ def main(**kwargs):
             momentum_dtype=torch.bfloat16,
             variance_dtype=torch.bfloat16,
             use_kahan_summation=False,
+            weight_decay=train_config.weight_decay,
         )
     else:
         optimizer = optim.AdamW(
             model.parameters(),
             lr=train_config.lr,
-            weight_decay=0.0,
+            weight_decay=train_config.weight_decay,
         )
     scheduler = StepLR(optimizer, step_size=1, gamma=train_config.gamma)
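The change replaces a hardcoded weight_decay=0.0 with the value from train_config, so a user-supplied weight decay now actually reaches both optimizer branches. Below is a minimal runnable sketch of how the changed lines fit together; the TrainConfig dataclass and its defaults (lr, gamma) are assumptions modeled on llama_recipes.configs.training, and only the weight_decay wiring is confirmed by the diff above. The AnyPrecisionAdamW branch is omitted and only the plain AdamW path is shown.

```python
from dataclasses import dataclass

import torch
from torch import optim
from torch.optim.lr_scheduler import StepLR


@dataclass
class TrainConfig:
    # Hypothetical stand-in for llama_recipes' train_config; field names
    # mirror the diff, defaults are assumptions for illustration.
    lr: float = 1e-4
    weight_decay: float = 0.0  # previously hardcoded to 0.0 in the optimizer call
    gamma: float = 0.85        # StepLR decay factor, as in the scheduler line above


def build_optimizer(model: torch.nn.Module, train_config: TrainConfig):
    # After this commit, weight_decay is forwarded from the config, so
    # setting train_config.weight_decay > 0 takes effect during fine-tuning.
    optimizer = optim.AdamW(
        model.parameters(),
        lr=train_config.lr,
        weight_decay=train_config.weight_decay,
    )
    scheduler = StepLR(optimizer, step_size=1, gamma=train_config.gamma)
    return optimizer, scheduler


if __name__ == "__main__":
    model = torch.nn.Linear(8, 2)
    optimizer, _ = build_optimizer(model, TrainConfig(weight_decay=0.1))
    print(optimizer.defaults["weight_decay"])  # 0.1
```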