Parcourir la source

Revert "replace init_empty_weights with torch.device(meta)"

This reverts commit c8d4f38d2330e14288b5dd882d0a275d01daa86c.
lchu il y a 1 an
Parent
commit
101391f46a
1 fichiers modifiés avec 3 ajouts et 1 suppressions
  1. 3 1
      llama_finetuning.py

+ 3 - 1
llama_finetuning.py

@@ -42,6 +42,8 @@ from utils.train_utils import (
     get_policies  
 )
 
+from accelerate import init_empty_weights
+
 from utils.dataset_utils import get_preprocessed_dataset
 
 from utils.config_utils import (
@@ -105,7 +107,7 @@ def main(**kwargs):
             )
         else:
             llama_config = LlamaConfig.from_pretrained(train_config.model_name)
-            with torch.device("meta"):
+            with init_empty_weights():
                 model = LlamaForCausalLM(llama_config)
     else:
         model = LlamaForCausalLM.from_pretrained(