|
@@ -10,7 +10,7 @@ import sys
|
|
|
import time
|
|
|
from typing import List
|
|
|
|
|
|
-from transformers import CodeLlamaTokenizer
|
|
|
+from transformers import AutoTokenizer
|
|
|
sys.path.append("..")
|
|
|
from safety_utils import get_safety_checker
|
|
|
from model_utils import load_model, load_peft_model, load_llama_from_config
|
|
@@ -46,7 +46,6 @@ def main(
|
|
|
else:
|
|
|
print("No user prompt provided. Exiting.")
|
|
|
sys.exit(1)
|
|
|
-
|
|
|
# Set the seeds for reproducibility
|
|
|
torch.cuda.manual_seed(seed)
|
|
|
torch.manual_seed(seed)
|
|
@@ -70,7 +69,7 @@ def main(
|
|
|
except ImportError:
|
|
|
print("Module 'optimum' not found. Please install 'optimum' it before proceeding.")
|
|
|
|
|
|
- tokenizer = CodeLlamaTokenizer.from_pretrained(model_name)
|
|
|
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
|
|
|
tokenizer.add_special_tokens(
|
|
|
{
|
|
|
|