@@ -42,8 +42,6 @@ from utils.train_utils import (
     get_policies
 )
-from accelerate import init_empty_weights
-
 from utils.dataset_utils import get_preprocessed_dataset
 
 from utils.config_utils import (
@@ -107,7 +105,7 @@ def main(**kwargs):
             )
         else:
             llama_config = LlamaConfig.from_pretrained(train_config.model_name)
-            with init_empty_weights():
+            with torch.device("meta"):
                 model = LlamaForCausalLM(llama_config)
     else:
         model = LlamaForCausalLM.from_pretrained(
@@ -148,19 +146,6 @@ def main(**kwargs):
         mixed_precision_policy, wrapping_policy = get_policies(fsdp_config, rank)
         my_auto_wrapping_policy = fsdp_auto_wrap_policy(model, LlamaDecoderLayer)
 
-        # given the fast evolving PRs around meta device init, I am not sure
-        # what is the best param_init_fn atm, maybe we can switch to simple to_emtpy.
-        def _param_init_fn(module: nn.Module):
-            torch.manual_seed(0)
-            for submodule in module.modules():
-                for param_name, param in submodule.named_parameters(recurse=False):
-                    if not _is_fsdp_flattened(param) and param.is_meta:
-                        materialized_param = nn.Parameter(
-                            torch.empty_like(param, device=torch.device("cuda"))
-                        )
-                        nn.init.uniform_(materialized_param)
-                        setattr(submodule, param_name, materialized_param)
-
         model = FSDP(
             model,
             auto_wrap_policy= my_auto_wrapping_policy if train_config.use_peft else wrapping_policy,
@@ -169,7 +154,7 @@ def main(**kwargs):
             device_id=torch.cuda.current_device(),
             limit_all_gathers=True,
             sync_module_states=True,
-            param_init_fn=None if rank == 0 else _param_init_fn,
+            param_init_fn=None if rank == 0 else lambda module: module.to_empty(device=torch.device("cuda"), recurse=False),
         )
         if fsdp_config.fsdp_activation_checkpointing:
             policies.apply_fsdp_checkpointing(model)
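
For context, the patch above swaps Accelerate's init_empty_weights for PyTorch's native meta device and lets FSDP itself materialize and fill the non-rank-0 copies. Below is a minimal, self-contained sketch of that pattern, not the finetuning.py code: the checkpoint name and the bare NCCL setup are placeholders, and it assumes a recent PyTorch with torch.device("meta") and Module.to_empty(recurse=...) support, launched via torchrun.

import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from transformers import LlamaConfig, LlamaForCausalLM

dist.init_process_group("nccl")
rank = dist.get_rank()
torch.cuda.set_device(rank % torch.cuda.device_count())

model_name = "meta-llama/Llama-2-7b-hf"  # placeholder checkpoint

if rank == 0:
    # Only rank 0 pays the CPU memory cost of loading real weights.
    model = LlamaForCausalLM.from_pretrained(model_name)
else:
    # Every other rank builds the module graph on the meta device:
    # parameters have shapes and dtypes but no storage.
    llama_config = LlamaConfig.from_pretrained(model_name)
    with torch.device("meta"):
        model = LlamaForCausalLM(llama_config)

model = FSDP(
    model,
    device_id=torch.cuda.current_device(),
    # Broadcast rank 0's real weights into every rank's shards.
    sync_module_states=True,
    # Non-zero ranks first need real, uninitialized CUDA storage for their
    # meta tensors so the broadcast has somewhere to land; to_empty does
    # exactly that, one module at a time (recurse=False).
    param_init_fn=None if rank == 0 else
        lambda module: module.to_empty(device=torch.device("cuda"), recurse=False),
)

Launched with torchrun --nproc_per_node=<num_gpus>, every rank ends up with the pretrained weights while only rank 0 ever reads the full checkpoint into CPU memory.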