@@ -11,8 +11,17 @@ from typing import List
 
 from transformers import LlamaTokenizer
 from safety_utils import get_safety_checker
-from model_utils import load_model, load_peft_model
+from model_utils import load_model, load_peft_model, load_llama_from_config
+
+# Get the current file's directory
+current_directory = os.path.dirname(os.path.abspath(__file__))
+
+# Get the parent directory
+parent_directory = os.path.dirname(current_directory)
+
+# Append the parent directory to sys.path
+sys.path.append(parent_directory)
+from model_checkpointing import load_sharded_model_single_gpu
 
 
 def main(
     model_name,
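The helper imported at the end of this hunk, load_sharded_model_single_gpu, lives in the model_checkpointing package one directory up, which is why the parent directory is appended to sys.path first. Its definition is not part of this diff; a minimal sketch of what such a helper could look like, assuming it wraps torch.distributed.checkpoint to read an FSDP sharded checkpoint in a single process:

import torch.distributed.checkpoint as dist_cp

# Hypothetical sketch -- the real implementation lives in model_checkpointing
# and is not shown in this diff.
def load_sharded_model_single_gpu(model, model_path):
    # Seed the state dict with the model's own (randomly initialized)
    # tensors so the checkpoint reader knows the expected keys and shapes.
    state_dict = {"model": model.state_dict()}
    # no_dist=True lets one process read a checkpoint written by many ranks.
    dist_cp.load_state_dict(
        state_dict=state_dict,
        storage_reader=dist_cp.FileSystemReader(model_path),
        no_dist=True,
    )
    model.load_state_dict(state_dict["model"])
    return model

Loading the consolidated weights into a plain (non-FSDP) module this way is what makes single-GPU inference on FSDP-trained checkpoints possible.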
@@ -49,7 +58,11 @@ def main(
     # Set the seeds for reproducibility
     torch.cuda.manual_seed(seed)
     torch.manual_seed(seed)
-    model = load_model(model_name, quantization)
+    # model = load_model(model_name, quantization)
+    model = load_llama_from_config()
+    model = load_sharded_model_single_gpu(model, model_name)
+
+    print("model has been loaded *******************")
 
     tokenizer = LlamaTokenizer.from_pretrained(model_name)
     tokenizer.add_special_tokens(
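load_llama_from_config() is called with no arguments, so it presumably falls back to a default config location; only the architecture is built here, and the weights come from the sharded checkpoint loaded on the next line. A rough sketch of such a helper, assuming a saved Hugging Face LlamaConfig (the default path is a made-up placeholder):

from transformers import LlamaConfig, LlamaForCausalLM

# Hypothetical sketch -- the default config path is an assumption, not
# something this diff shows.
def load_llama_from_config(config_path="./model_config.json"):
    model_config = LlamaConfig.from_pretrained(config_path)
    # Instantiate the architecture only; parameters stay randomly
    # initialized until load_sharded_model_single_gpu copies the
    # checkpoint weights into the module.
    model = LlamaForCausalLM(config=model_config)
    return model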