|
@@ -2,40 +2,26 @@
|
|
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
|
|
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
|
|
|
|
|
|
import os
|
|
import os
|
|
-import sys
|
|
|
|
-from typing import List
|
|
|
|
-import yaml
|
|
|
|
import time
|
|
import time
|
|
|
|
+import yaml
|
|
|
|
+from pathlib import Path
|
|
|
|
+from pkg_resources import packaging
|
|
|
|
+
|
|
|
|
|
|
-import fire
|
|
|
|
import torch
|
|
import torch
|
|
-import transformers
|
|
|
|
-from datasets import load_dataset
|
|
|
|
-from tqdm import tqdm
|
|
|
|
-"""
|
|
|
|
-Unused imports:
|
|
|
|
-import torch.nn as nn
|
|
|
|
-import bitsandbytes as bnb
|
|
|
|
-"""
|
|
|
|
-from torch.nn import functional as F
|
|
|
|
-from peft import (
|
|
|
|
- LoraConfig,
|
|
|
|
- get_peft_model,
|
|
|
|
- get_peft_model_state_dict,
|
|
|
|
- prepare_model_for_int8_training,
|
|
|
|
- set_peft_model_state_dict,
|
|
|
|
-)
|
|
|
|
-from transformers import LlamaForCausalLM, LlamaTokenizer
|
|
|
|
-from torch.distributed.fsdp import StateDictType
|
|
|
|
-import torch.distributed as dist
|
|
|
|
-from pkg_resources import packaging
|
|
|
|
-from .memory_utils import MemoryTrace
|
|
|
|
-import model_checkpointing
|
|
|
|
import torch.cuda.nccl as nccl
|
|
import torch.cuda.nccl as nccl
|
|
|
|
+import torch.distributed as dist
|
|
|
|
+from torch.distributed.fsdp import StateDictType
|
|
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
|
|
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
|
|
-from pathlib import Path
|
|
|
|
-sys.path.append(str(Path(__file__).resolve().parent.parent))
|
|
|
|
-from policies import bfSixteen, fpSixteen,bfSixteen_mixed, get_llama_wrapper
|
|
|
|
|
|
+from tqdm import tqdm
|
|
|
|
+from transformers import LlamaTokenizer
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+from llama_recipes.model_checkpointing import save_model_checkpoint, save_model_and_optimizer_sharded, save_optimizer_checkpoint
|
|
|
|
+from llama_recipes.policies import fpSixteen,bfSixteen_mixed, get_llama_wrapper
|
|
|
|
+from llama_recipes.utils.memory_utils import MemoryTrace
|
|
|
|
+from accelerate.utils import is_xpu_available, is_ccl_available
|
|
|
|
+
|
|
from accelerate.utils import is_xpu_available, is_ccl_available
|
|
from accelerate.utils import is_xpu_available, is_ccl_available
|
|
|
|
|
|
def set_tokenizer_params(tokenizer: LlamaTokenizer):
|
|
def set_tokenizer_params(tokenizer: LlamaTokenizer):
|
|
@@ -84,7 +70,9 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
|
|
with MemoryTrace() as memtrace: # track the memory usage
|
|
with MemoryTrace() as memtrace: # track the memory usage
|
|
model.train()
|
|
model.train()
|
|
total_loss = 0.0
|
|
total_loss = 0.0
|
|
- for step, batch in enumerate(tqdm(train_dataloader,colour="blue", desc=f"Training Epoch{epoch}")):
|
|
|
|
|
|
+ total_length = len(train_dataloader)//gradient_accumulation_steps
|
|
|
|
+ pbar = tqdm(colour="blue", desc=f"Training Epoch: {epoch}", total=total_length)
|
|
|
|
+ for step, batch in enumerate(train_dataloader):
|
|
for key in batch.keys():
|
|
for key in batch.keys():
|
|
if train_config.enable_fsdp:
|
|
if train_config.enable_fsdp:
|
|
batch[key] = batch[key].to(local_rank)
|
|
batch[key] = batch[key].to(local_rank)
|
|
@@ -103,17 +91,17 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
|
|
scaler.step(optimizer)
|
|
scaler.step(optimizer)
|
|
scaler.update()
|
|
scaler.update()
|
|
optimizer.zero_grad()
|
|
optimizer.zero_grad()
|
|
|
|
+ pbar.update(step//gradient_accumulation_steps)
|
|
else:
|
|
else:
|
|
# regular backpropagation when fp16 is not used
|
|
# regular backpropagation when fp16 is not used
|
|
loss.backward()
|
|
loss.backward()
|
|
if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
|
|
if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
|
|
optimizer.step()
|
|
optimizer.step()
|
|
optimizer.zero_grad()
|
|
optimizer.zero_grad()
|
|
- if train_config.enable_fsdp:
|
|
|
|
- if rank==0:
|
|
|
|
- print(f"\n step {step} is completed and loss is {loss.detach().float()}")
|
|
|
|
- else:
|
|
|
|
- print(f"\n step {step} is completed and loss is {loss.detach().float()}")
|
|
|
|
|
|
+ pbar.update(step//gradient_accumulation_steps)
|
|
|
|
+
|
|
|
|
+ pbar.set_description(f"Training Epoch: {epoch}/{train_config.num_epochs}, step {step}/{len(train_dataloader)} completed (loss: {loss.detach().float()})")
|
|
|
|
+
|
|
epoch_end_time = time.perf_counter()-epoch_start_time
|
|
epoch_end_time = time.perf_counter()-epoch_start_time
|
|
epoch_times.append(epoch_end_time)
|
|
epoch_times.append(epoch_end_time)
|
|
# Reducing total_loss across all devices if there's more than one CUDA device
|
|
# Reducing total_loss across all devices if there's more than one CUDA device
|
|
@@ -180,21 +168,21 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
|
|
else:
|
|
else:
|
|
if not train_config.use_peft and fsdp_config.checkpoint_type == StateDictType.FULL_STATE_DICT:
|
|
if not train_config.use_peft and fsdp_config.checkpoint_type == StateDictType.FULL_STATE_DICT:
|
|
|
|
|
|
- model_checkpointing.save_model_checkpoint(
|
|
|
|
|
|
+ save_model_checkpoint(
|
|
model, optimizer, rank, train_config, epoch=epoch
|
|
model, optimizer, rank, train_config, epoch=epoch
|
|
)
|
|
)
|
|
elif not train_config.use_peft and fsdp_config.checkpoint_type == StateDictType.SHARDED_STATE_DICT:
|
|
elif not train_config.use_peft and fsdp_config.checkpoint_type == StateDictType.SHARDED_STATE_DICT:
|
|
print(" Saving the FSDP model checkpoints using SHARDED_STATE_DICT")
|
|
print(" Saving the FSDP model checkpoints using SHARDED_STATE_DICT")
|
|
print("=====================================================")
|
|
print("=====================================================")
|
|
|
|
|
|
- model_checkpointing.save_model_and_optimizer_sharded(model, rank, train_config)
|
|
|
|
|
|
+ save_model_and_optimizer_sharded(model, rank, train_config)
|
|
if train_config.save_optimizer:
|
|
if train_config.save_optimizer:
|
|
- model_checkpointing.save_model_and_optimizer_sharded(model, rank, train_config, optim=optimizer)
|
|
|
|
|
|
+ save_model_and_optimizer_sharded(model, rank, train_config, optim=optimizer)
|
|
print(" Saving the FSDP model checkpoints and optimizer using SHARDED_STATE_DICT")
|
|
print(" Saving the FSDP model checkpoints and optimizer using SHARDED_STATE_DICT")
|
|
print("=====================================================")
|
|
print("=====================================================")
|
|
|
|
|
|
if not train_config.use_peft and train_config.save_optimizer:
|
|
if not train_config.use_peft and train_config.save_optimizer:
|
|
- model_checkpointing.save_optimizer_checkpoint(
|
|
|
|
|
|
+ save_optimizer_checkpoint(
|
|
model, optimizer, rank, train_config, epoch=epoch
|
|
model, optimizer, rank, train_config, epoch=epoch
|
|
)
|
|
)
|
|
print(" Saving the FSDP model checkpoints and optimizer using FULL_STATE_DICT")
|
|
print(" Saving the FSDP model checkpoints and optimizer using FULL_STATE_DICT")
|
|
@@ -212,14 +200,13 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
|
|
print(f"best eval loss on epoch {epoch} is {best_val_loss}")
|
|
print(f"best eval loss on epoch {epoch} is {best_val_loss}")
|
|
val_loss.append(best_val_loss)
|
|
val_loss.append(best_val_loss)
|
|
val_prep.append(eval_ppl)
|
|
val_prep.append(eval_ppl)
|
|
-
|
|
|
|
if train_config.enable_fsdp:
|
|
if train_config.enable_fsdp:
|
|
if rank==0:
|
|
if rank==0:
|
|
print(f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epcoh time {epoch_end_time}s")
|
|
print(f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epcoh time {epoch_end_time}s")
|
|
else:
|
|
else:
|
|
print(f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epcoh time {epoch_end_time}s")
|
|
print(f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epcoh time {epoch_end_time}s")
|
|
- avg_epoch_time = sum(epoch_times)/ len(epoch_times)
|
|
|
|
- avg_checkpoint_time = sum(checkpoint_times)/ len(checkpoint_times)
|
|
|
|
|
|
+ avg_epoch_time = sum(epoch_times)/ len(epoch_times)
|
|
|
|
+ avg_checkpoint_time = sum(checkpoint_times)/ len(checkpoint_times) if len(checkpoint_times) > 0 else 0
|
|
avg_train_prep = sum(train_prep)/len(train_prep)
|
|
avg_train_prep = sum(train_prep)/len(train_prep)
|
|
avg_train_loss = sum(train_loss)/len(train_loss)
|
|
avg_train_loss = sum(train_loss)/len(train_loss)
|
|
if train_config.run_validation:
|
|
if train_config.run_validation:
|