|
@@ -42,6 +42,8 @@ from utils.train_utils import (
|
|
|
get_policies
|
|
|
)
|
|
|
|
|
|
+from accelerate import init_empty_weights
|
|
|
+
|
|
|
from utils.dataset_utils import get_preprocessed_dataset
|
|
|
|
|
|
from utils.config_utils import (
|
|
@@ -105,7 +107,7 @@ def main(**kwargs):
|
|
|
)
|
|
|
else:
|
|
|
llama_config = LlamaConfig.from_pretrained(train_config.model_name)
|
|
|
- with torch.device("meta"):
|
|
|
+ with init_empty_weights():
|
|
|
model = LlamaForCausalLM(llama_config)
|
|
|
else:
|
|
|
model = LlamaForCausalLM.from_pretrained(
|