소스 검색

replace init_empty_weights with torch.device(meta)

lchu 1 년 전
부모
커밋
c8d4f38d23
1개의 변경된 파일1개의 추가작업 그리고 3개의 파일을 삭제
  1. 1 3
      llama_finetuning.py

+ 1 - 3
llama_finetuning.py

@@ -42,8 +42,6 @@ from utils.train_utils import (
     get_policies  
 )
 
-from accelerate import init_empty_weights
-
 from utils.dataset_utils import get_preprocessed_dataset
 
 from utils.config_utils import (
@@ -107,7 +105,7 @@ def main(**kwargs):
             )
         else:
             llama_config = LlamaConfig.from_pretrained(train_config.model_name)
-            with init_empty_weights():
+            with torch.device("meta"):
                 model = LlamaForCausalLM(llama_config)
     else:
         model = LlamaForCausalLM.from_pretrained(