
Add train_params.yaml saving for FSDP checkpoint loading for inference

Hamid Shojanazeri 1 year ago
parent commit 231c9e7da9
1 changed file with 42 additions and 1 deletion
  1. +42 -1 utils/train_utils.py

+ 42 - 1
utils/train_utils.py

@@ -4,6 +4,8 @@
 import os
 import sys
 from typing import List
+import yaml
+from pathlib import Path
 
 import fire
 import torch
@@ -174,7 +176,10 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
         results['avg_eval_prep'] = avg_eval_prep
         results['avg_eval_loss'] = avg_eval_loss
         
-
+    # Save the training params, including the FSDP settings, for future reference.
+    if train_config.enable_fsdp and fsdp_config:
+        save_train_params(train_config, fsdp_config)
+
     return results
 
 def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer):
@@ -323,3 +328,39 @@ def get_policies(cfg, rank):
             print(f"bFloat16 support not present. Using FP32, and not mixed precision")
     wrapping_policy = get_llama_wrapper()
     return mixed_precision_policy, wrapping_policy
+
+def save_train_params(train_config, fsdp_config):
+    """
+    This function saves the train_config and FSDP config into a train_params.yaml.
+    This will be used by converter script in the inference folder to fetch the HF model name or path.
+    It also would be hepful as a log for future references.
+    """
+    # Convert the train_config and fsdp_config objects to dictionaries, 
+    # converting all values to strings to ensure they can be serialized into a YAML file
+    train_config_dict = {k: str(v) for k, v in vars(train_config).items() if not k.startswith('__')}
+    fsdp_config_dict = {k: str(v) for k, v in vars(fsdp_config).items() if not k.startswith('__')}
+    # Merge the two dictionaries into one
+    train_params_dict = {**train_config_dict, **fsdp_config_dict}
+    # Construct the folder name (following the FSDP checkpointing style) using properties of the train_config object
+    folder_name = (
+        train_config.dist_checkpoint_root_folder
+        + "/"
+        + train_config.dist_checkpoint_folder
+        + "-"
+        + train_config.model_name
+    )
+
+    save_dir = Path.cwd() / folder_name
+    # Create the save directory if it does not already exist
+    os.makedirs(save_dir, exist_ok=True)
+    # Convert the dictionary to a YAML string
+    config_yaml = yaml.dump(train_params_dict, indent=4)
+    file_name = os.path.join(save_dir, 'train_params.yaml')
+
+    # Check if there's a directory with the same name as the file
+    if os.path.isdir(file_name):
+        print(f"Error: {file_name} is a directory, not a file.")
+    else:
+        # Write the YAML string to the file
+        with open(file_name, 'w') as f:
+            f.write(config_yaml)
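
For context, here is a minimal sketch of how the converter script in the inference folder, mentioned in the docstring, might read train_params.yaml back to fetch the HF model name or path. The helper name load_train_params and the example checkpoint path are hypothetical; only the file name and the model_name key mirror what save_train_params writes above.

import os
import yaml

def load_train_params(checkpoint_dir):
    # train_params.yaml sits next to the FSDP checkpoint files, in
    # {dist_checkpoint_root_folder}/{dist_checkpoint_folder}-{model_name}
    file_name = os.path.join(checkpoint_dir, 'train_params.yaml')
    with open(file_name, 'r') as f:
        # safe_load is sufficient: save_train_params stringifies every value
        return yaml.safe_load(f)

# Hypothetical usage: recover the HF model name or path for inference.
params = load_train_params('PATH/TO/CHECKPOINT_FOLDER')
model_name = params['model_name']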