@@ -2,40 +2,25 @@
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
 
import os
-import sys
-from typing import List
-import yaml
import time
+import yaml
+from pathlib import Path
+from pkg_resources import packaging
+
 
-import fire
import torch
-import transformers
-from datasets import load_dataset
-from tqdm import tqdm
-"""
-Unused imports:
-import torch.nn as nn
-import bitsandbytes as bnb
-"""
-from torch.nn import functional as F
-from peft import (
- LoraConfig,
- get_peft_model,
- get_peft_model_state_dict,
- prepare_model_for_int8_training,
- set_peft_model_state_dict,
-)
-from transformers import LlamaForCausalLM, LlamaTokenizer
-from torch.distributed.fsdp import StateDictType
-import torch.distributed as dist
-from pkg_resources import packaging
-from .memory_utils import MemoryTrace
-import model_checkpointing
import torch.cuda.nccl as nccl
+import torch.distributed as dist
+from torch.distributed.fsdp import StateDictType
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
-from pathlib import Path
-sys.path.append(str(Path(__file__).resolve().parent.parent))
-from policies import bfSixteen, fpSixteen,bfSixteen_mixed, get_llama_wrapper
+from tqdm import tqdm
+from transformers import LlamaTokenizer
+
+
+from .memory_utils import MemoryTrace
+from ..model_checkpointing import save_model_checkpoint, save_model_and_optimizer_sharded, save_optimizer_checkpoint
+from ..policies import fpSixteen,bfSixteen_mixed, get_llama_wrapper
+
 
def set_tokenizer_params(tokenizer: LlamaTokenizer):
tokenizer.pad_token_id = 0
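
The reworked import block above brings in `packaging`, `torch.cuda.nccl`, and `torch.distributed` alongside the FSDP pieces. These imports are typically combined into a bfloat16-capability check before enabling mixed precision. The sketch below illustrates such a check; it is not code from this diff, and the function name `bf16_ready` is hypothetical.

```python
import torch
import torch.cuda.nccl as nccl
import torch.distributed as dist
from pkg_resources import packaging


def bf16_ready() -> bool:
    """Hypothetical capability probe: bf16 mixed precision generally needs a
    CUDA >= 11 build, NCCL >= 2.10, and hardware that reports bf16 support."""
    return bool(
        torch.version.cuda                      # None on CPU-only builds
        and torch.cuda.is_available()
        and torch.cuda.is_bf16_supported()
        and packaging.version.parse(torch.version.cuda).release >= (11, 0)
        and dist.is_nccl_available()
        and nccl.version() >= (2, 10)
    )


if __name__ == "__main__":
    print(f"bf16 mixed precision usable: {bf16_ready()}")
```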
@@ -162,21 +147,21 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
else:
if not train_config.use_peft and fsdp_config.checkpoint_type == StateDictType.FULL_STATE_DICT:
 
- model_checkpointing.save_model_checkpoint(
+ save_model_checkpoint(
model, optimizer, rank, train_config, epoch=epoch
)
elif not train_config.use_peft and fsdp_config.checkpoint_type == StateDictType.SHARDED_STATE_DICT:
print(" Saving the FSDP model checkpoints using SHARDED_STATE_DICT")
print("=====================================================")
 
- model_checkpointing.save_model_and_optimizer_sharded(model, rank, train_config)
+ save_model_and_optimizer_sharded(model, rank, train_config)
if train_config.save_optimizer:
- model_checkpointing.save_model_and_optimizer_sharded(model, rank, train_config, optim=optimizer)
+ save_model_and_optimizer_sharded(model, rank, train_config, optim=optimizer)
print(" Saving the FSDP model checkpoints and optimizer using SHARDED_STATE_DICT")
print("=====================================================")
 
if not train_config.use_peft and train_config.save_optimizer:
- model_checkpointing.save_optimizer_checkpoint(
+ save_optimizer_checkpoint(
model, optimizer, rank, train_config, epoch=epoch
)
print(" Saving the FSDP model checkpoints and optimizer using FULL_STATE_DICT")
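
For call sites like the ones above to drop the `model_checkpointing.` prefix, the `..model_checkpointing` package has to export the three helpers by name. A minimal sketch of such an `__init__.py` follows; the submodule name `checkpoint_handler` is an assumption, and only the three function names are taken from the diff.

```python
# model_checkpointing/__init__.py -- sketch of the re-exports that the new
# `from ..model_checkpointing import ...` line relies on.
# The submodule name is hypothetical; only the three public names appear in the diff.
from .checkpoint_handler import (
    save_model_checkpoint,
    save_model_and_optimizer_sharded,
    save_optimizer_checkpoint,
)

__all__ = [
    "save_model_checkpoint",
    "save_model_and_optimizer_sharded",
    "save_optimizer_checkpoint",
]
```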