
add rank to save_train_params

Hamid Shojanazeri, 1 year ago
commit 668c364f6b
1 changed file with 5 additions and 4 deletions

utils/train_utils.py  +5 -4

@@ -79,7 +79,6 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
             model.train()
             total_loss = 0.0
             data_set_len = 0
-            
             for step, batch in enumerate(tqdm(train_dataloader,colour="blue", desc=f"Training Epoch{epoch}")):
                 for key in batch.keys():
                     if train_config.enable_fsdp:
@@ -177,7 +176,7 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
         
     #saving the training params including fsdp setting for reference.
     if train_config.enable_fsdp and fsdp_config:
-        save_train_params(train_config, fsdp_config)
+        save_train_params(train_config, fsdp_config, rank)
         
     return results

@@ -328,7 +327,7 @@ def get_policies(cfg, rank):
     wrapping_policy = get_llama_wrapper()
     return mixed_precision_policy, wrapping_policy

-def save_train_params(train_config, fsdp_config):
+def save_train_params(train_config, fsdp_config, rank):
     """
     This function saves the train_config and FSDP config into a train_params.yaml.
     This will be used by converter script in the inference folder to fetch the HF model name or path.
@@ -363,4 +362,6 @@ def save_train_params(train_config, fsdp_config):
     else:
         # Write the YAML string to the file
         with open(file_name, 'w') as f:
-            f.write(config_yaml)
+            f.write(config_yaml)
+        if rank==0:
+            print(f"training params are saved in {file_name}")
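For context: under FSDP, every rank in the process group calls save_train_params, so without a guard the confirmation message would be printed once per GPU process. Passing rank down lets the function print only on rank 0. Below is a minimal, self-contained sketch of that pattern; the config dataclasses and the output path are illustrative assumptions, not the repo's exact definitions.

```python
# Sketch of the rank-aware save introduced by this commit.
# TrainConfig/FSDPConfig and the file path are hypothetical stand-ins.
from dataclasses import dataclass, asdict
import yaml

@dataclass
class TrainConfig:
    model_name: str = "meta-llama/Llama-2-7b-hf"
    enable_fsdp: bool = True

@dataclass
class FSDPConfig:
    sharding_strategy: str = "FULL_SHARD"

def save_train_params(train_config, fsdp_config, rank):
    # Merge both configs into one dict and dump them as YAML.
    config = {**asdict(train_config), **asdict(fsdp_config)}
    file_name = "train_params.yaml"
    with open(file_name, "w") as f:
        f.write(yaml.dump(config))
    # Every rank executes this function, so gate the console message
    # on rank 0; otherwise it appears once per process.
    if rank == 0:
        print(f"training params are saved in {file_name}")

# Example call, mirroring the updated call site in train():
save_train_params(TrainConfig(), FSDPConfig(), rank=0)
```

Note that the file itself is still written by every rank (each process overwrites the same YAML); only the print is rank-guarded, which is enough to deduplicate console output.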