Hamid Shojanazeri 1 rok temu
rodzic
commit
ce966a97e0
1 zmienionych plików z 3 dodań i 7 usunięć
  1. 3 7
      examples/code_llama/code_instruct_example.py

+ 3 - 7
examples/code_llama/code_instruct_example.py

@@ -1,8 +1,6 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
 
-# from accelerate import init_empty_weights, load_checkpoint_and_dispatch
-
 import fire
 import os
 import sys
@@ -60,7 +58,6 @@ def main(
     peft_model: str=None,
     quantization: bool=False,
     max_new_tokens =100, #The maximum numbers of tokens to generate
-    prompt_file: str=None,
     seed: int=42, #seed value for reproducibility
     do_sample: bool=True, #Whether or not to use sampling ; use greedy decoding otherwise.
     min_length: int=None, #The minimum length of the sequence to be generated, input prompt + min_new_tokens
@@ -100,11 +97,10 @@ def main(
         based on the hardware being used. This would speed up inference when used for batched inputs.
         """
         try:
-            from optimum.bettertransformer import BetterTransformer
-            model = BetterTransformer.transform(model)    
+            model.to_bettertransformer()   
         except ImportError:
-            print("Module 'optimum' not found. Please install 'optimum' it before proceeding.")
-
+            print("Please check the Transformers version that support to_bettertransformer natively.")
+        
     tokenizer = AutoTokenizer.from_pretrained(model_name)
     safety_checker = get_safety_checker(enable_azure_content_safety,
                                         enable_sensitive_topics,