ソースを参照

replace init_empty_weights with torch.device(meta)

lchu 1 年間 前
コミット
c8d4f38d23
1 ファイル変更1 行追加3 行削除
  1. 1 3
      llama_finetuning.py

+ 1 - 3
llama_finetuning.py

@@ -42,8 +42,6 @@ from utils.train_utils import (
     get_policies  
 )
 
-from accelerate import init_empty_weights
-
 from utils.dataset_utils import get_preprocessed_dataset
 
 from utils.config_utils import (
@@ -107,7 +105,7 @@ def main(**kwargs):
             )
         else:
             llama_config = LlamaConfig.from_pretrained(train_config.model_name)
-            with init_empty_weights():
+            with torch.device("meta"):
                 model = LlamaForCausalLM(llama_config)
     else:
         model = LlamaForCausalLM.from_pretrained(