|
@@ -4,16 +4,16 @@
|
|
# from accelerate import init_empty_weights, load_checkpoint_and_dispatch
|
|
# from accelerate import init_empty_weights, load_checkpoint_and_dispatch
|
|
|
|
|
|
import fire
|
|
import fire
|
|
-import torch
|
|
|
|
import os
|
|
import os
|
|
import sys
|
|
import sys
|
|
import time
|
|
import time
|
|
-from typing import List
|
|
|
|
|
|
|
|
|
|
+import torch
|
|
from transformers import AutoTokenizer
|
|
from transformers import AutoTokenizer
|
|
-sys.path.append("..")
|
|
|
|
-from safety_utils import get_safety_checker
|
|
|
|
-from model_utils import load_model, load_peft_model, load_llama_from_config
|
|
|
|
|
|
+
|
|
|
|
+from ..safety_utils import get_safety_checker
|
|
|
|
+from ..model_utils import load_model, load_peft_model
|
|
|
|
+
|
|
|
|
|
|
def main(
|
|
def main(
|
|
model_name,
|
|
model_name,
|