Prechádzať zdrojové kódy

fix the save_train_param condition

Hamid Shojanazeri 1 rok pred
rodič
commit
88d3e1febc
1 zmenil súbory, kde vykonal 1 pridanie a 1 odobranie
  1. 1 1
      utils/train_utils.py

+ 1 - 1
utils/train_utils.py

@@ -206,7 +206,7 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
         results['avg_eval_loss'] = avg_eval_loss
         
     #saving the training params including fsdp setting for reference.
-    if train_config.enable_fsdp and fsdp_config:
+    if train_config.enable_fsdp and not train_config.use_peft:
         save_train_params(train_config, fsdp_config, rank)
         
     return results