@@ -14,6 +14,7 @@ from transformers import LlamaConfig, LlamaTokenizer, LlamaForCausalLM
 from safety_utils import get_safety_checker
 from model_utils import load_model, load_peft_model
 from chat_utils import read_dialogs_from_file, format_tokens
+from accelerate.utils import is_xpu_available
 
 def main(
     model_name,
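
For context, a minimal sketch of how the newly imported `is_xpu_available` helper is typically used for device selection later in a script like this one; the `get_device` wrapper below is illustrative only and is not part of this diff.

```python
import torch
from accelerate.utils import is_xpu_available


def get_device() -> str:
    """Pick the accelerator: Intel XPU if present, else CUDA, else CPU.

    Illustrative helper, not taken from the patched file.
    """
    if is_xpu_available():
        return "xpu"
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"


# Example: move model inputs to the selected device before generation,
# e.g. batch = {k: v.to(get_device()) for k, v in batch.items()}
print(get_device())
```

Routing device selection through one check like this keeps the rest of the script free of per-backend branches when XPU support is added alongside CUDA.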