|
@@ -4,6 +4,7 @@
|
|
|
import os
|
|
|
import sys
|
|
|
from typing import List
|
|
|
+import yaml
|
|
|
|
|
|
import fire
|
|
|
import torch
|
|
@@ -173,7 +174,10 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
|
|
|
results['avg_eval_prep'] = avg_eval_prep
|
|
|
results['avg_eval_loss'] = avg_eval_loss
|
|
|
|
|
|
-
|
|
|
+ #saving the training params including fsdp setting for reference.
|
|
|
+ if train_config.enable_fsdp and fsdp_config:
|
|
|
+ save_train_params(train_config, fsdp_config, rank)
|
|
|
+
|
|
|
return results
|
|
|
|
|
|
def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer):
|
|
@@ -322,3 +326,42 @@ def get_policies(cfg, rank):
|
|
|
print(f"bFloat16 support not present. Using FP32, and not mixed precision")
|
|
|
wrapping_policy = get_llama_wrapper()
|
|
|
return mixed_precision_policy, wrapping_policy
|
|
|
+
|
|
|
+def save_train_params(train_config, fsdp_config, rank):
|
|
|
+ """
|
|
|
+ This function saves the train_config and FSDP config into a train_params.yaml.
|
|
|
+ This will be used by converter script in the inference folder to fetch the HF model name or path.
|
|
|
+ It also would be hepful as a log for future references.
|
|
|
+ """
|
|
|
+ # Convert the train_config and fsdp_config objects to dictionaries,
|
|
|
+ # converting all values to strings to ensure they can be serialized into a YAML file
|
|
|
+ train_config_dict = {k: str(v) for k, v in vars(train_config).items() if not k.startswith('__')}
|
|
|
+ fsdp_config_dict = {k: str(v) for k, v in vars(fsdp_config).items() if not k.startswith('__')}
|
|
|
+ # Merge the two dictionaries into one
|
|
|
+ train_params_dict = {**train_config_dict, **fsdp_config_dict}
|
|
|
+ # Construct the folder name (follwoing FSDP checkpointing style) using properties of the train_config object
|
|
|
+ folder_name = (
|
|
|
+ train_config.dist_checkpoint_root_folder
|
|
|
+ + "/"
|
|
|
+ + train_config.dist_checkpoint_folder
|
|
|
+ + "-"
|
|
|
+ + train_config.model_name
|
|
|
+ )
|
|
|
+
|
|
|
+ save_dir = Path.cwd() / folder_name
|
|
|
+ # If the directory does not exist, create it
|
|
|
+ if not os.path.exists(save_dir):
|
|
|
+ os.makedirs(save_dir)
|
|
|
+ # Convert the dictionary to a YAML string
|
|
|
+ config_yaml = yaml.dump(train_params_dict, indent=4)
|
|
|
+ file_name = os.path.join(save_dir,'train_params.yaml')
|
|
|
+
|
|
|
+ # Check if there's a directory with the same name as the file
|
|
|
+ if os.path.isdir(file_name):
|
|
|
+ print(f"Error: {file_name} is a directory, not a file.")
|
|
|
+ else:
|
|
|
+ # Write the YAML string to the file
|
|
|
+ with open(file_name, 'w') as f:
|
|
|
+ f.write(config_yaml)
|
|
|
+ if rank==0:
|
|
|
+ print(f"training params are saved in {file_name}")
|