@@ -2,72 +2,48 @@
 # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
 
 import os
-import sys
-from typing import List, Union
 
 import fire
 import torch
-import transformers
-from datasets import load_dataset
-import os.path as osp
-from tqdm import tqdm
-
-# Unused imports removed
-from utils import fsdp_auto_wrap_policy
+import torch.distributed as dist
+import torch.optim as optim
+from peft import get_peft_model, prepare_model_for_int8_training
+from pkg_resources import packaging
+from torch.distributed.fsdp import (
+    FullyShardedDataParallel as FSDP,
+)
+from torch.optim.lr_scheduler import StepLR
+from torch.utils.data import DistributedSampler
 from transformers import (
     LlamaForCausalLM,
     LlamaTokenizer,
     LlamaConfig,
-    AutoModelForCausalLM,
-    AutoModelForSeq2SeqLM,
-    AutoTokenizer,
     default_data_collator,
-    BitsAndBytesConfig
 )
-import torch.distributed as dist
+from transformers.models.llama.modeling_llama import LlamaDecoderLayer
+
+import policies
+from configs import fsdp_config, train_config
+from policies import AnyPrecisionAdamW
+
+from utils import fsdp_auto_wrap_policy
+from utils.config_utils import (
+    update_config,
+    generate_peft_config,
+    generate_dataset_config,
+)
+from utils.dataset_utils import get_preprocessed_dataset
 
-# Unused imports removed
 from utils.train_utils import (
-    set_tokenizer_params,
     train,
-    evaluation,
     freeze_transformer_layers,
-    check_frozen_layers_peft_model,
     setup,
     setup_environ_flags,
-    cleanup,
     clear_gpu_cache,
-    get_parameter_dtypes,
     print_model_size,
     get_policies
 )
 
-from utils.dataset_utils import get_preprocessed_dataset
-
-from utils.config_utils import (
-    update_config,
-    generate_peft_config,
-    generate_dataset_config,
-)
-from peft import get_peft_model, TaskType, prepare_model_for_int8_training
-import configs
-from torch.distributed.fsdp import (
-    FullyShardedDataParallel as FSDP,
-    MixedPrecision,
-)
-from torch.utils.data import DistributedSampler
-import policies
-from policies import AnyPrecisionAdamW
-from configs import fsdp_config, train_config
-import torch.optim as optim
-from torch.optim.lr_scheduler import StepLR
-from pkg_resources import packaging
-import torch
-import torch.nn as nn
-import torch.cuda.nccl as nccl
-import torch.distributed as dist
-from transformers.models.llama.modeling_llama import LlamaDecoderLayer
-
 
 def main(**kwargs):
     # Update the configuration for the training and sharding process