@@ -62,7 +62,7 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
     Returns: results dictionary containing average training and validation perplexity and loss
     """
     # Create a gradient scaler for fp16
-    if train_config.use_fp16 and train_config.enable_fsdp:
+    if train_config.use_fp16 and train_config.enable_fsdp:
         scaler = ShardedGradScaler()
     elif train_config.use_fp16 and not train_config.enable_fsdp:
         scaler = torch.cuda.amp.GradScaler()
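
For context, here is a minimal sketch of how either scaler is consumed by the training loop. The `train_config`, `model`, `optimizer`, and `train_dataloader` names are carried over from the hunk above; the `run_fp16_epoch` wrapper and the loop body are illustrative assumptions, not the repository's actual code:

```python
import torch
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler

def run_fp16_epoch(model, train_dataloader, optimizer, train_config):
    # Hypothetical helper. Pick the scaler that matches the setup,
    # mirroring the hunk above: FSDP shards gradients across ranks,
    # so it needs the sharded variant of the gradient scaler.
    if train_config.use_fp16 and train_config.enable_fsdp:
        scaler = ShardedGradScaler()
    elif train_config.use_fp16 and not train_config.enable_fsdp:
        scaler = torch.cuda.amp.GradScaler()

    for batch in train_dataloader:
        optimizer.zero_grad()
        # Run the forward pass in fp16 via autocast.
        with torch.cuda.amp.autocast(dtype=torch.float16):
            loss = model(**batch).loss
        scaler.scale(loss).backward()  # scale the loss so fp16 grads don't underflow
        scaler.step(optimizer)         # unscales grads; skips the step on inf/NaN
        scaler.update()                # adapt the scale factor for the next step
```

The distinction matters because `ShardedGradScaler` coordinates its inf/NaN check across ranks, so every FSDP shard agrees on whether to skip the optimizer step, whereas the plain `torch.cuda.amp.GradScaler` only inspects local gradients.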