|
@@ -1,8 +1,6 @@
|
|
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
|
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
|
|
|
|
|
|
-# from accelerate import init_empty_weights, load_checkpoint_and_dispatch
|
|
|
-
|
|
|
import fire
|
|
|
import os
|
|
|
import sys
|
|
@@ -60,7 +58,6 @@ def main(
|
|
|
peft_model: str=None,
|
|
|
quantization: bool=False,
|
|
|
max_new_tokens =100, #The maximum numbers of tokens to generate
|
|
|
- prompt_file: str=None,
|
|
|
seed: int=42, #seed value for reproducibility
|
|
|
do_sample: bool=True, #Whether or not to use sampling ; use greedy decoding otherwise.
|
|
|
min_length: int=None, #The minimum length of the sequence to be generated, input prompt + min_new_tokens
|
|
@@ -100,11 +97,10 @@ def main(
|
|
|
based on the hardware being used. This would speed up inference when used for batched inputs.
|
|
|
"""
|
|
|
try:
|
|
|
- from optimum.bettertransformer import BetterTransformer
|
|
|
- model = BetterTransformer.transform(model)
|
|
|
+ model.to_bettertransformer()
|
|
|
except ImportError:
|
|
|
- print("Module 'optimum' not found. Please install 'optimum' it before proceeding.")
|
|
|
-
|
|
|
+    print("Please check that your Transformers version natively supports to_bettertransformer.")
|
|
|
+
|
|
|
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
|
|
safety_checker = get_safety_checker(enable_azure_content_safety,
|
|
|
enable_sensitive_topics,
|