Browse Source

add native BT

Hamid Shojanazeri 1 year ago
parent
commit
ce966a97e0
1 changed files with 3 additions and 7 deletions
  1. 3 7
      examples/code_llama/code_instruct_example.py

+ 3 - 7
examples/code_llama/code_instruct_example.py

@@ -1,8 +1,6 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
 
-# from accelerate import init_empty_weights, load_checkpoint_and_dispatch
-
 import fire
 import os
 import sys
@@ -60,7 +58,6 @@ def main(
     peft_model: str=None,
     quantization: bool=False,
     max_new_tokens =100, #The maximum numbers of tokens to generate
-    prompt_file: str=None,
     seed: int=42, #seed value for reproducibility
     do_sample: bool=True, #Whether or not to use sampling ; use greedy decoding otherwise.
     min_length: int=None, #The minimum length of the sequence to be generated, input prompt + min_new_tokens
@@ -100,11 +97,10 @@ def main(
         based on the hardware being used. This would speed up inference when used for batched inputs.
         """
         try:
-            from optimum.bettertransformer import BetterTransformer
-            model = BetterTransformer.transform(model)    
+            model.to_bettertransformer()   
         except ImportError:
-            print("Module 'optimum' not found. Please install 'optimum' it before proceeding.")
-
+            print("Please check that your Transformers version supports to_bettertransformer natively.")
+        
     tokenizer = AutoTokenizer.from_pretrained(model_name)
     safety_checker = get_safety_checker(enable_azure_content_safety,
                                         enable_sensitive_topics,