
add rank to save_train_params

Hamid Shojanazeri · 1 year ago · commit 668c364f6b
1 changed file with 5 additions and 4 deletions

utils/train_utils.py  (+5 -4)

@@ -79,7 +79,6 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
             model.train()
             total_loss = 0.0
             data_set_len = 0
-            
             for step, batch in enumerate(tqdm(train_dataloader,colour="blue", desc=f"Training Epoch{epoch}")):
                 for key in batch.keys():
                     if train_config.enable_fsdp:
@@ -177,7 +176,7 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
         
     #saving the training params including fsdp setting for reference.
     if train_config.enable_fsdp and fsdp_config:
-        save_train_params(train_config, fsdp_config)
+        save_train_params(train_config, fsdp_config, rank)
         
     return results
 
@@ -328,7 +327,7 @@ def get_policies(cfg, rank):
     wrapping_policy = get_llama_wrapper()
     return mixed_precision_policy, wrapping_policy
 
-def save_train_params(train_config, fsdp_config):
+def save_train_params(train_config, fsdp_config, rank):
     """
     This function saves the train_config and FSDP config into a train_params.yaml.
     This will be used by converter script in the inference folder to fetch the HF model name or path.
@@ -363,4 +362,6 @@ def save_train_params(train_config, fsdp_config):
     else:
         # Write the YAML string to the file
         with open(file_name, 'w') as f:
-            f.write(config_yaml)
+            f.write(config_yaml)
+        if rank==0:
+            print(f"training params are saved in {file_name}")