
# Inference

For inference we have provided an inference script. The script takes different arguments depending on the type of finetuning performed during training. If all model parameters were fine-tuned, pass the training output directory as the --model_name argument. For a parameter-efficient method such as LoRA, pass the base model as --model_name and the training output directory as the --peft_model argument. Additionally, a prompt for the model has to be provided as a text file, either piped through standard input or given as the --prompt_file parameter.
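For reference, the two argument combinations correspond roughly to the following loading logic. This is a minimal sketch using the transformers and peft libraries, not the actual inference script; all paths are placeholders.

```python
# Sketch of the two loading paths described above (placeholder paths, not the real script).
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

# Full finetuning: the training output dir is a complete model, load it directly.
model = AutoModelForCausalLM.from_pretrained("<training_output_dir>")
tokenizer = AutoTokenizer.from_pretrained("<training_output_dir>")

# PEFT (e.g. LoRA): load the base model first, then attach the trained adapter.
base = AutoModelForCausalLM.from_pretrained("<base_model_name>")
model = PeftModel.from_pretrained(base, "<training_output_dir>")
tokenizer = AutoTokenizer.from_pretrained("<base_model_name>")
```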

## Content Safety

The inference script also supports safety checks for both the user prompt and the model outputs. In particular, we use two packages, AuditNLG and Azure Content Safety.

**Note** If using Azure Content Safety, please make sure to get the endpoint and API key as described here and add them as the following environment variables: `CONTENT_SAFETY_ENDPOINT` and `CONTENT_SAFETY_KEY`.

Examples:

```bash
# Full finetuning of all parameters
cat <test_prompt_file> | python inference/inference.py --model_name <training_config.output_dir> --use_auditnlg

# PEFT method
cat <test_prompt_file> | python inference/inference.py --model_name <training_config.model_name> --peft_model <training_config.output_dir> --use_auditnlg

# prompt as parameter
python inference/inference.py --model_name <training_config.output_dir> --prompt_file <test_prompt_file> --use_auditnlg
```

The inference folder contains test prompts for the summarization use case:

```
inference/samsum_prompt.txt
...
```

## Chat completion

The inference folder also includes a chat completion example, which applies the built-in safety features of fine-tuned models to the prompt tokens. To run the example:

```bash
python chat_completion.py --model_name "PATH/TO/MODEL/7B/" --prompt_file chats.json --quantization --use_auditnlg
```

## Loading back FSDP checkpoints

If you have fine-tuned your model with pure FSDP and saved the checkpoints with "SHARDED_STATE_DICT" as shown here, you can use the converter script to convert the FSDP sharded checkpoints into HuggingFace checkpoints. This enables you to use the inference script normally, as described above. To convert the checkpoint, use the following command:

```bash
python checkpoint_converter_fsdp_hf.py --model_name PATH/to/FSDP/Checkpoints --save_dir PATH/to/save/checkpoints --model_path PATH/or/HF/model_name

# --model_path specifies the HF Llama model name or path where it has config.json and tokenizer.json
```
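After conversion, the checkpoints in --save_dir behave like any other HuggingFace checkpoint. Below is a minimal sketch of loading them for generation, assuming the tokenizer was saved alongside the model; the path is a placeholder.

```python
# Loading the converted checkpoint as a regular HuggingFace model.
# "PATH/to/save/checkpoints" is the --save_dir used in the conversion step above.
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("PATH/to/save/checkpoints")
tokenizer = AutoTokenizer.from_pretrained("PATH/to/save/checkpoints")
```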

## Other Inference Options

Alternate inference options include:

vLLM: To use vLLM you will need to install it using the instructions here. Once installed, you can use the vLLM_inference.py script provided here.

Below is an example of how to run the vLLM_inference.py script found within the inference folder.

```bash
python vLLM_inference.py --model_name <PATH/TO/MODEL/7B>
```
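If you prefer to call vLLM's Python API directly instead of the wrapper script, a minimal sketch looks like the following; the model path, prompt, and sampling settings are placeholder values, not values taken from the script.

```python
# Sketch of generating text with vLLM's Python API directly.
from vllm import LLM, SamplingParams

llm = LLM(model="PATH/TO/MODEL/7B")                      # placeholder model path
sampling_params = SamplingParams(temperature=0.8, max_tokens=100)

outputs = llm.generate(["I believe the meaning of life is"], sampling_params)
for output in outputs:
    print(output.outputs[0].text)
```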

TGI: Text Generation Inference (TGI) is another inference option available to you. For more information on how to set up and use TGI, see here.
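Once a TGI server is running, it exposes a REST API that you can query over HTTP. Below is a minimal sketch using TGI's /generate endpoint with the requests library; the host, port, prompt, and generation parameters are assumptions for illustration, not values from this repo.

```python
# Sketch of querying a running TGI server over its /generate REST endpoint.
import requests

response = requests.post(
    "http://localhost:8080/generate",               # assumed host/port of your TGI server
    json={
        "inputs": "I believe the meaning of life is",
        "parameters": {"max_new_tokens": 100},
    },
)
print(response.json()["generated_text"])
```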