@@ -35,11 +35,6 @@ from pathlib import Path
 sys.path.append(str(Path(__file__).resolve().parent.parent))
 from policies import bfSixteen, fpSixteen,bfSixteen_mixed, get_llama_wrapper
 
-scaler = ShardedGradScaler()
-
-
-
-
 def set_tokenizer_params(tokenizer: LlamaTokenizer):
     tokenizer.pad_token_id = 0
     tokenizer.padding_side = "left"
@@ -67,8 +62,11 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
     Returns: results dictionary containing average training and validation perplexity and loss
     """
     # Create a gradient scaler for fp16
-    scaler = torch.cuda.amp.GradScaler() if train_config.use_fp16 else None
-
+    if train_config.use_fp16 and train_config.enable_fsdp:
+        scaler = ShardedGradScaler()
+    elif train_config.use_fp16 and not train_config.enable_fsdp:
+        scaler = torch.cuda.amp.GradScaler()
+
     train_prep = []
     train_loss = []
     val_prep = []
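
For context, a minimal sketch (not part of the patch) of how a scaler chosen this way is typically consumed in the fp16 training step: ShardedGradScaler exposes the same scale/step/update interface as torch.cuda.amp.GradScaler, so the loop body does not need to branch on whether FSDP is enabled. The helper name training_step and the batch/loss variables below are illustrative assumptions, not code from this repository.

import torch

def training_step(model, batch, optimizer, scaler, use_fp16):
    # Hypothetical helper for illustration only; assumes a HF-style model that
    # returns an object with a .loss attribute when given labels in `batch`.
    with torch.cuda.amp.autocast(enabled=use_fp16):
        loss = model(**batch).loss
    if use_fp16:
        # Scale the loss before backward to avoid fp16 gradient underflow,
        # then let the scaler unscale the grads, step, and update its scale.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
    else:
        # Full-precision path: the patch creates no scaler in this case,
        # so the optimizer is stepped directly.
        loss.backward()
        optimizer.step()
    optimizer.zero_grad()
    return loss.detach()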