@@ -12,7 +12,7 @@ from typing import List
 from transformers import LlamaTokenizer
 from safety_utils import get_safety_checker
 from model_utils import load_model, load_peft_model, load_llama_from_config
-
+from accelerate import init_empty_weights
 # Get the current file's directory
 current_directory = os.path.dirname(os.path.abspath(__file__))

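Note on the new import: `init_empty_weights` is an Accelerate context manager that instantiates a model skeleton on the "meta" device, so no parameter memory is allocated before a real checkpoint is read in. Whether `load_llama_from_config` actually routes through it is not visible in this hunk; below is a minimal sketch of the general pattern, with a placeholder config path:

    from accelerate import init_empty_weights
    from transformers import LlamaConfig, LlamaForCausalLM

    config = LlamaConfig.from_pretrained("path/to/7B/")  # placeholder path
    with init_empty_weights():
        # Parameters are created on the "meta" device: shapes and dtypes
        # only, no storage, so even a 7B model is instantiated instantly.
        model = LlamaForCausalLM(config)
    print(next(model.parameters()).device)  # meta

A meta-device model cannot be fed directly to a plain `load_state_dict` call; its weights must be materialized first.
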
@@ -59,9 +59,12 @@ def main(
     torch.cuda.manual_seed(seed)
     torch.manual_seed(seed)
     # model = load_model(model_name, quantization)
-    model_config = load_llama_from_config()
-    model = load_sharded_model_single_gpu(model_config, model_name)
-
+    model_def = load_llama_from_config()
+    # print(dir(model_def))
+    # model_def.eval()
+    model = load_sharded_model_single_gpu(model_def, model_name)
+    model.to(torch.bfloat16)
+    model.to("cuda:0")
     print("model has been loaded *******************")

     tokenizer = LlamaTokenizer.from_pretrained("../../../hf-llama-pr/7B/")
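
The new load path builds the model from a config, pulls a sharded (FSDP-style) checkpoint into it on a single process, then casts to bfloat16 before the device copy. The PR's actual `load_sharded_model_single_gpu` body is not shown in this diff; a minimal sketch of what such a helper could look like, assuming the checkpoint was saved with `torch.distributed.checkpoint`:

    import torch.distributed.checkpoint as dist_cp

    def load_sharded_model_single_gpu(model, model_path):
        # Read the sharded checkpoint into the model's own state-dict
        # layout; no_dist=True lets this run without initializing a
        # torch.distributed process group.
        state_dict = {"model": model.state_dict()}
        dist_cp.load_state_dict(
            state_dict=state_dict,
            storage_reader=dist_cp.FileSystemReader(model_path),
            no_dist=True,
        )
        model.load_state_dict(state_dict["model"])
        return model

This sketch assumes `model` holds real (CPU) tensors rather than meta-device ones. Ordering `model.to(torch.bfloat16)` before `model.to("cuda:0")` also means any fp32-to-bf16 cast happens on the host, so the host-to-device copy moves half as many bytes.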