瀏覽代碼

Revert "replace init_empty_weights with torch.device(meta)"

This reverts commit c8d4f38d2330e14288b5dd882d0a275d01daa86c.
lchu 1 年之前
父節點
當前提交
101391f46a
共有 1 個文件被更改,包括 3 次插入1 次删除
  1. 3 1
      llama_finetuning.py

+ 3 - 1
llama_finetuning.py

@@ -42,6 +42,8 @@ from utils.train_utils import (
     get_policies  
 )
 
+from accelerate import init_empty_weights
+
 from utils.dataset_utils import get_preprocessed_dataset
 
 from utils.config_utils import (
@@ -105,7 +107,7 @@ def main(**kwargs):
             )
         else:
             llama_config = LlamaConfig.from_pretrained(train_config.model_name)
-            with torch.device("meta"):
+            with init_empty_weights():
                 model = LlamaForCausalLM(llama_config)
     else:
         model = LlamaForCausalLM.from_pretrained(