|
@@ -8,12 +8,11 @@ import torch
|
|
|
import os
|
|
|
import sys
|
|
|
import time
|
|
|
-from typing import List
|
|
|
|
|
|
from transformers import AutoTokenizer
|
|
|
-sys.path.append("..")
|
|
|
-from safety_utils import get_safety_checker
|
|
|
-from model_utils import load_model, load_peft_model, load_llama_from_config
|
|
|
+
|
|
|
+from ..safety_utils import get_safety_checker
|
|
|
+from ..model_utils import load_model, load_peft_model
|
|
|
|
|
|
def main(
|
|
|
model_name,
|