Browse Source

adding inference updates

Hamid Shojanazeri 1 year ago
parent
commit
7ad5bd5ef2
1 changed file with 7 additions and 4 deletions
  1. inference/inference.py (+7 −4)

+ 7 - 4
inference/inference.py

@@ -12,7 +12,7 @@ from typing import List
 from transformers import LlamaTokenizer
 from safety_utils import get_safety_checker
 from model_utils import load_model, load_peft_model, load_llama_from_config
-
+from accelerate import init_empty_weights
 # Get the current file's directory
 current_directory = os.path.dirname(os.path.abspath(__file__))
 
@@ -59,9 +59,12 @@ def main(
     torch.cuda.manual_seed(seed)
     torch.manual_seed(seed)
     # model = load_model(model_name, quantization)
-    model_config = load_llama_from_config()
-    model = load_sharded_model_single_gpu(model_config, model_name)
-    
+    model_def = load_llama_from_config()
+    # print(dir(model_def))
+    # model_def.eval()
+    model = load_sharded_model_single_gpu(model_def, model_name)
+    model.to(torch.bfloat16)
+    model.to("cuda:0")
     print("model has been loaded *******************")
 
     tokenizer = LlamaTokenizer.from_pretrained("../../../hf-llama-pr/7B/")