@@ -17,6 +17,7 @@ from utils import fsdp_auto_wrap_policy
 from transformers import (
     LlamaForCausalLM,
     LlamaTokenizer,
+    LlamaConfig,
     AutoModelForCausalLM,
     AutoModelForSeq2SeqLM,
     AutoTokenizer,
@@ -41,6 +42,8 @@ from utils.train_utils import (
     get_policies
 )
 
+from accelerate import init_empty_weights
+
 from utils.dataset_utils import get_preprocessed_dataset
 
 from utils.config_utils import (
@@ -62,8 +65,10 @@ import torch.optim as optim
 from torch.optim.lr_scheduler import StepLR
 from pkg_resources import packaging
 import torch
+import torch.nn as nn
 import torch.cuda.nccl as nccl
 import torch.distributed as dist
+from torch.distributed.fsdp._common_utils import _is_fsdp_flattened
 from transformers.models.llama.modeling_llama import LlamaDecoderLayer
 
 
@@ -90,11 +95,26 @@ def main(**kwargs):
     gradient_accumulation_steps = train_config.batch_size_training // train_config.micro_batch_size
 
     # Load the pre-trained model and setup its configuration
-    model = LlamaForCausalLM.from_pretrained(
-        train_config.model_name,
-        load_in_8bit=True if train_config.quantization else None,
-        device_map="auto" if train_config.quantization else None,
-    )
+    if train_config.enable_fsdp:
+        # For FSDP, we save CPU memory by loading the pretrained model on rank 0 only.
+        # This avoids CPU OOM when loading large models such as Llama 70B, where the
+        # model alone would consume 2+ TB of CPU memory (70B params * 4 bytes * 8 ranks).
+        if rank == 0:
+            model = LlamaForCausalLM.from_pretrained(
+                train_config.model_name,
+                load_in_8bit=True if train_config.quantization else None,
+                device_map="auto" if train_config.quantization else None,
+            )
+        else:
+            llama_config = LlamaConfig.from_pretrained(train_config.model_name)
+            with init_empty_weights():
+                model = LlamaForCausalLM(llama_config)
+    else:
+        model = LlamaForCausalLM.from_pretrained(
+            train_config.model_name,
+            load_in_8bit=True if train_config.quantization else None,
+            device_map="auto" if train_config.quantization else None,
+        )
 
     print_model_size(model, train_config, rank if train_config.enable_fsdp else 0)
 
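As an aside (not part of the patch): the non-rank-0 branch above relies on accelerate's init_empty_weights() context manager, under which every parameter is created on the "meta" device and allocates no real storage. A minimal standalone sketch, using the default LlamaConfig() as a placeholder for train_config.model_name:

    from accelerate import init_empty_weights
    from transformers import LlamaConfig, LlamaForCausalLM

    config = LlamaConfig()  # placeholder; the patch loads it from train_config.model_name
    with init_empty_weights():
        skeleton = LlamaForCausalLM(config)  # built instantly, no weight memory allocated

    param = next(skeleton.parameters())
    print(param.is_meta, param.device)  # prints: True meta

The skeleton cannot be used for computation until its parameters are materialized, which is what the param_init_fn and sync_module_states arguments passed to FSDP below take care of.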
@@ -127,7 +147,20 @@ def main(**kwargs):
 
         mixed_precision_policy, wrapping_policy = get_policies(fsdp_config, rank)
         my_auto_wrapping_policy = fsdp_auto_wrap_policy(model, LlamaDecoderLayer)
-
+
+        # Given the fast-evolving PRs around meta-device init, it is not yet clear
+        # which param_init_fn is best; we may be able to switch to a simple to_empty().
+        def _param_init_fn(module: nn.Module):
+            torch.manual_seed(0)
+            for submodule in module.modules():
+                for param_name, param in submodule.named_parameters(recurse=False):
+                    if not _is_fsdp_flattened(param) and param.is_meta:
+                        materialized_param = nn.Parameter(
+                            torch.empty_like(param, device=torch.device("cuda"))
+                        )
+                        nn.init.uniform_(materialized_param)
+                        setattr(submodule, param_name, materialized_param)
+
         model = FSDP(
             model,
             auto_wrap_policy= my_auto_wrapping_policy if train_config.use_peft else wrapping_policy,
@@ -135,6 +168,8 @@ def main(**kwargs):
             sharding_strategy=fsdp_config.sharding_strategy,
             device_id=torch.cuda.current_device(),
             limit_all_gathers=True,
+            sync_module_states=True,
+            param_init_fn=None if rank == 0 else _param_init_fn,
         )
         if fsdp_config.fsdp_activation_checkpointing:
             policies.apply_fsdp_checkpointing(model)
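A note on how the two new FSDP arguments work together: sync_module_states=True makes FSDP broadcast module parameters and buffers from rank 0 while wrapping, so rank 0 (which holds the real pretrained weights) needs no param_init_fn, and the other ranks only need _param_init_fn to turn their meta tensors into allocated GPU tensors; the uniform values it writes are then overwritten by the broadcast. The comment in the patch mentions possibly switching to a simple to_empty(); a hypothetical sketch of that alternative (an assumption, not what the patch ships) could look like:

    import torch
    import torch.nn as nn

    def _to_empty_param_init_fn(module: nn.Module):
        # Move meta-device parameters and buffers onto the GPU without initializing
        # their values; the garbage contents are then replaced by rank 0's weights
        # because FSDP is constructed with sync_module_states=True.
        module.to_empty(device=torch.device("cuda"))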