switch to simpler param_init_fn and meta device init

lchu, 1 year ago
parent commit 1e64fc98d9
1 changed file with 2 additions and 17 deletions
llama_finetuning.py  +2 -17

@@ -42,8 +42,6 @@ from utils.train_utils import (
     get_policies  
 )
 
-from accelerate import init_empty_weights
-
 from utils.dataset_utils import get_preprocessed_dataset
 
 from utils.config_utils import (
@@ -107,7 +105,7 @@ def main(**kwargs):
             )
         else:
             llama_config = LlamaConfig.from_pretrained(train_config.model_name)
-            with init_empty_weights():
+            with torch.device("meta"):
                 model = LlamaForCausalLM(llama_config)
     else:
         model = LlamaForCausalLM.from_pretrained(
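
The removed accelerate import is no longer needed here: since PyTorch 2.0 a torch.device can be used as a context manager, so non-zero ranks can build the model skeleton on the meta device without accelerate's init_empty_weights. A minimal sketch of what the new context manager does, using a tiny hypothetical config rather than the repo's settings (the script itself uses LlamaConfig.from_pretrained(train_config.model_name)):

    import torch
    from transformers import LlamaConfig, LlamaForCausalLM

    # Tiny config purely for illustration.
    config = LlamaConfig(hidden_size=256, intermediate_size=512,
                         num_hidden_layers=2, num_attention_heads=4)

    # Inside the context manager all tensor factory calls target the meta device:
    # shapes and dtypes exist, but no storage is allocated and no init kernels run,
    # so construction is near-instant and costs no host memory.
    with torch.device("meta"):
        model = LlamaForCausalLM(config)

    print(next(model.parameters()).is_meta)  # True
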
@@ -148,19 +146,6 @@ def main(**kwargs):
         mixed_precision_policy, wrapping_policy = get_policies(fsdp_config, rank)
         my_auto_wrapping_policy = fsdp_auto_wrap_policy(model, LlamaDecoderLayer)
 
-        # given the fast evolving PRs around meta device init, I am not sure
-        # what is the best param_init_fn atm, maybe we can switch to simple to_emtpy.
-        def _param_init_fn(module: nn.Module):
-            torch.manual_seed(0)
-            for submodule in module.modules():
-                for param_name, param in submodule.named_parameters(recurse=False):
-                    if not _is_fsdp_flattened(param) and param.is_meta:
-                        materialized_param = nn.Parameter(
-                            torch.empty_like(param, device=torch.device("cuda"))
-                        )
-                        nn.init.uniform_(materialized_param)
-                        setattr(submodule, param_name, materialized_param)
-
         model = FSDP(
             model,
             auto_wrap_policy= my_auto_wrapping_policy if train_config.use_peft else wrapping_policy,
@@ -169,7 +154,7 @@ def main(**kwargs):
             device_id=torch.cuda.current_device(),
             limit_all_gathers=True,
             sync_module_states=True,
-            param_init_fn=None if rank == 0 else _param_init_fn,
+            param_init_fn=None if rank == 0 else lambda module: module.to_empty(device=torch.device("cuda"), recurse=False),
         )
         if fsdp_config.fsdp_activation_checkpointing:
             policies.apply_fsdp_checkpointing(model)
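
Putting the hunks together, the pattern the commit lands on is: rank 0 loads the real weights, every other rank builds a meta-device skeleton, and FSDP's param_init_fn materializes each wrapped unit on the GPU with to_empty() while sync_module_states=True broadcasts rank 0's weights over them. A minimal standalone sketch under those assumptions (launched with torchrun; the checkpoint path is hypothetical, and to_empty(recurse=...) needs a recent PyTorch):

    import torch
    import torch.distributed as dist
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
    from transformers import LlamaConfig, LlamaForCausalLM

    dist.init_process_group("nccl")
    rank = dist.get_rank()
    torch.cuda.set_device(rank % torch.cuda.device_count())

    model_name = "path/to/llama-checkpoint"  # hypothetical; the script uses train_config.model_name

    if rank == 0:
        # Only rank 0 pays the cost of loading real weights from disk.
        model = LlamaForCausalLM.from_pretrained(model_name)
    else:
        # Every other rank builds an empty skeleton on the meta device.
        with torch.device("meta"):
            model = LlamaForCausalLM(LlamaConfig.from_pretrained(model_name))

    model = FSDP(
        model,
        device_id=torch.cuda.current_device(),
        limit_all_gathers=True,
        # to_empty() gives meta parameters uninitialized GPU storage for the current
        # module only (recurse=False, since FSDP calls this per wrapped unit); the
        # garbage values are then overwritten by the broadcast below.
        param_init_fn=None if rank == 0 else
            lambda module: module.to_empty(device=torch.device("cuda"), recurse=False),
        # Broadcast rank 0's real parameters and buffers to all other ranks.
        sync_module_states=True,
    )

Compared with the deleted _param_init_fn, this relies on the broadcast from rank 0 for correct values instead of seeding a uniform init on every rank, which is why plain to_empty() is sufficient.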