@@ -79,7 +79,6 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
            model.train()
            total_loss = 0.0
            data_set_len = 0
-
            for step, batch in enumerate(tqdm(train_dataloader,colour="blue", desc=f"Training Epoch{epoch}")):
                for key in batch.keys():
                    if train_config.enable_fsdp:
@@ -177,7 +176,7 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche

    #saving the training params including fsdp setting for reference.
    if train_config.enable_fsdp and fsdp_config:
-        save_train_params(train_config, fsdp_config)
+        save_train_params(train_config, fsdp_config, rank)

    return results

@@ -328,7 +327,7 @@ def get_policies(cfg, rank):
    wrapping_policy = get_llama_wrapper()
    return mixed_precision_policy, wrapping_policy

-def save_train_params(train_config, fsdp_config):
+def save_train_params(train_config, fsdp_config, rank):
    """
    This function saves the train_config and FSDP config into a train_params.yaml.
    This will be used by converter script in the inference folder to fetch the HF model name or path.
@@ -363,4 +362,6 @@ def save_train_params(train_config, fsdp_config):
    else:
        # Write the YAML string to the file
        with open(file_name, 'w') as f:
-            f.write(config_yaml)
+            f.write(config_yaml)
+        if rank==0:
+            print(f"training params are saved in {file_name}")
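
For reference, the patch threads rank into save_train_params so the confirmation message is printed only on rank 0 instead of once per FSDP worker. A minimal standalone sketch of that rank-0 guard pattern, assuming torch.distributed has been initialized by the launcher; the log_once helper below is hypothetical and not part of the repo:

import torch.distributed as dist

def log_once(message, rank=None):
    # Resolve the global rank if the caller did not pass one;
    # fall back to 0 for single-process (non-distributed) runs.
    if rank is None:
        rank = dist.get_rank() if dist.is_initialized() else 0
    # Only rank 0 prints, so multi-GPU runs emit the message once.
    if rank == 0:
        print(message)

log_once("training params are saved in train_params.yaml", rank=0)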