Copyright (c) Meta Platforms, Inc. and affiliates. This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
This notebook shows how to train a Llama 2 model on a single GPU (e.g. A10 with 24GB) using int8 quantization and LoRA.
The example uses the Hugging Face Trainer and model classes, which means the checkpoint has to be converted from its original format into the dedicated Hugging Face format.
The conversion can be achieved by running the convert_llama_weights_to_hf.py
script provided with the transformers package.
Given that the original checkpoint resides under models/7B,
we can install all requirements and convert the checkpoint with:
# %%bash
# pip install transformers datasets accelerate sentencepiece protobuf==3.20 py7zr scipy peft bitsandbytes fire torch_tb_profiler ipywidgets
# TRANSFORM=`python -c "import transformers;print('/'.join(transformers.__file__.split('/')[:-1])+'/models/llama/convert_llama_weights_to_hf.py')"`
# python ${TRANSFORM} --input_dir models --model_size 7B --output_dir models_hf/7B
Point model_id to the model weight folder:
import torch
from transformers import LlamaForCausalLM, LlamaTokenizer
model_id="./models_hf/7B"
tokenizer = LlamaTokenizer.from_pretrained(model_id)
model = LlamaForCausalLM.from_pretrained(model_id, load_in_8bit=True, device_map='auto', torch_dtype=torch.float16)
===================================BUG REPORT===================================
Welcome to bitsandbytes. For bug reports, please run
python -m bitsandbytes
and submit this information together with your error trace to: https://github.com/TimDettmers/bitsandbytes/issues
================================================================================
CUDA SETUP: CUDA runtime path found: /usr/local/cuda/targets/x86_64-linux/lib/libcudart.so
CUDA SETUP: Highest compute capability among GPUs detected: 8.0
CUDA SETUP: Detected CUDA version 112
CUDA SETUP: Loading binary /data/home/mreso/miniconda3/envs/llama/lib/python3.10/site-packages/bitsandbytes/libbitsandbytes_cuda112.so...
Loading checkpoint shards: 100%|██████████| 2/2 [00:10<00:00,  5.09s/it]
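Note: newer transformers releases deprecate passing load_in_8bit directly to from_pretrained in favor of a BitsAndBytesConfig. A minimal sketch of the equivalent call, assuming a recent transformers version with bitsandbytes installed:
from transformers import BitsAndBytesConfig

quant_config = BitsAndBytesConfig(load_in_8bit=True)
model = LlamaForCausalLM.from_pretrained(
    model_id,
    quantization_config=quant_config,  # replaces load_in_8bit=True
    device_map='auto',
    torch_dtype=torch.float16,
)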
We load and preprocess the samsum dataset, which consists of curated pairs of dialogs and their summaries:
from pathlib import Path
import os
import sys
from utils.dataset_utils import get_preprocessed_dataset
from configs.datasets import samsum_dataset
train_dataset = get_preprocessed_dataset(tokenizer, samsum_dataset, 'train')
Found cached dataset samsum (/data/home/mreso/.cache/huggingface/datasets/samsum/samsum/0.0.0/f1d7c6b7353e6de335d444e424dc002ef70d1277109031327bc9cc6af5d3d46e)
Loading cached processed dataset at /data/home/mreso/.cache/huggingface/datasets/samsum/samsum/0.0.0/f1d7c6b7353e6de335d444e424dc002ef70d1277109031327bc9cc6af5d3d46e/cache-b14554a76c1c7ecd.arrow
Loading cached processed dataset at /data/home/mreso/.cache/huggingface/datasets/samsum/samsum/0.0.0/f1d7c6b7353e6de335d444e424dc002ef70d1277109031327bc9cc6af5d3d46e/cache-e40e61e15ebeb527.arrow
Loading cached processed dataset at /data/home/mreso/.cache/huggingface/datasets/samsum/samsum/0.0.0/f1d7c6b7353e6de335d444e424dc002ef70d1277109031327bc9cc6af5d3d46e/cache-e08ac9e1b792e7ba.arrow
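To sanity-check the preprocessing, we can decode one training sample back to text. This is a quick optional check, assuming each processed sample exposes an input_ids field:
# Optional sanity check (assumes the processed samples carry an "input_ids" field)
print(len(train_dataset))
print(tokenizer.decode(train_dataset[0]["input_ids"]))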
Run the base model on an example input:
eval_prompt = """
Summarize this dialog:
A: Hi Tom, are you busy tomorrow’s afternoon?
B: I’m pretty sure I am. What’s up?
A: Can you go with me to the animal shelter?.
B: What do you want to do?
A: I want to get a puppy for my son.
B: That will make him so happy.
A: Yeah, we’ve discussed it many times. I think he’s ready now.
B: That’s good. Raising a dog is a tough issue. Like having a baby ;-)
A: I'll get him one of those little dogs.
B: One that won't grow up too big;-)
A: And eat too much;-))
B: Do you know which one he would like?
A: Oh, yes, I took him there last Monday. He showed me one that he really liked.
B: I bet you had to drag him away.
A: He wanted to take it home right away ;-).
B: I wonder what he'll name it.
A: He said he’d name it after his dead hamster – Lemmy - he's a great Motorhead fan :-)))
---
Summary:
"""
model_input = tokenizer(eval_prompt, return_tensors="pt").to("cuda")
model.eval()
with torch.no_grad():
    print(tokenizer.decode(model.generate(**model_input, max_new_tokens=100)[0], skip_special_tokens=True))
Summarize this dialog: A: Hi Tom, are you busy tomorrow’s afternoon? B: I’m pretty sure I am. What’s up? A: Can you go with me to the animal shelter?. B: What do you want to do? A: I want to get a puppy for my son. B: That will make him so happy. A: Yeah, we’ve discussed it many times. I think he’s ready now. B: That’s good. Raising a dog is a tough issue. Like having a baby ;-) A: I'll get him one of those little dogs. B: One that won't grow up too big;-) A: And eat too much;-)) B: Do you know which one he would like? A: Oh, yes, I took him there last Monday. He showed me one that he really liked. B: I bet you had to drag him away. A: He wanted to take it home right away ;-). B: I wonder what he'll name it. A: He said he’d name it after his dead hamster – Lemmy - he's a great Motorhead fan :-))) --- Summary: A: Hi Tom, are you busy tomorrow’s afternoon? B: I’m pretty sure I am. What’s up? A: Can you go with me to the animal shelter?. B: What do you want to do? A: I want to get a puppy for my son. B: That will make him so happy. A: Yeah, we’ve discussed it many times. I think he’s ready now. B
We can see that the base model only repeats the conversation.
Let's prepare the model for Parameter Efficient Fine Tuning (PEFT):
model.train()

def create_peft_config(model):
    from peft import (
        get_peft_model,
        LoraConfig,
        TaskType,
        prepare_model_for_int8_training,
    )

    peft_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        inference_mode=False,
        r=8,
        lora_alpha=32,
        lora_dropout=0.05,
        target_modules=["q_proj", "v_proj"],
    )

    # prepare int-8 model for training
    model = prepare_model_for_int8_training(model)
    model = get_peft_model(model, peft_config)
    model.print_trainable_parameters()
    return model, peft_config

# create peft config
model, lora_config = create_peft_config(model)
trainable params: 4194304 || all params: 6742609920 || trainable%: 0.06220594176090199
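The 4,194,304 trainable parameters are exactly the LoRA adapter weights. A quick back-of-the-envelope check, assuming the 7B model's 32 decoder layers and a hidden size of 4096, reproduces the number:
# LoRA adds two low-rank matrices per target module: A (r x d) and B (d x r).
hidden_size = 4096                        # Llama 2 7B hidden dimension
num_layers = 32                           # decoder layers in the 7B model
r = 8                                     # LoRA rank from the config above
params_per_module = 2 * hidden_size * r   # A and B together
num_modules = 2 * num_layers              # q_proj and v_proj in every layer
print(num_modules * params_per_module)    # 4194304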
from transformers import TrainerCallback
from contextlib import nullcontext
enable_profiler = False
output_dir = "tmp/llama-output"
config = {
    'lora_config': lora_config,
    'learning_rate': 1e-4,
    'num_train_epochs': 1,
    'gradient_accumulation_steps': 2,
    'per_device_train_batch_size': 2,
    'gradient_checkpointing': False,
}
# Set up profiler
if enable_profiler:
    wait, warmup, active, repeat = 1, 1, 2, 1
    total_steps = (wait + warmup + active) * (1 + repeat)
    schedule = torch.profiler.schedule(wait=wait, warmup=warmup, active=active, repeat=repeat)
    profiler = torch.profiler.profile(
        schedule=schedule,
        on_trace_ready=torch.profiler.tensorboard_trace_handler(f"{output_dir}/logs/tensorboard"),
        record_shapes=True,
        profile_memory=True,
        with_stack=True)

    class ProfilerCallback(TrainerCallback):
        def __init__(self, profiler):
            self.profiler = profiler

        def on_step_end(self, *args, **kwargs):
            self.profiler.step()

    profiler_callback = ProfilerCallback(profiler)
else:
    profiler = nullcontext()
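With this configuration, each optimizer step sees per_device_train_batch_size * gradient_accumulation_steps = 2 * 2 = 4 samples. If profiling is enabled, the schedule above gives total_steps = (1 + 1 + 2) * (1 + 1) = 8, so training is capped at eight steps to keep the trace small.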
Here, we fine-tune the model for a single epoch, which takes a bit more than an hour on an A100.
from transformers import default_data_collator, Trainer, TrainingArguments
# Define training args
training_args = TrainingArguments(
    output_dir=output_dir,
    overwrite_output_dir=True,
    bf16=True,  # Use BF16 if available
    # logging strategies
    logging_dir=f"{output_dir}/logs",
    logging_strategy="steps",
    logging_steps=10,
    save_strategy="no",
    optim="adamw_torch_fused",
    max_steps=total_steps if enable_profiler else -1,
    **{k: v for k, v in config.items() if k != 'lora_config'}
)
with profiler:
    # Create Trainer instance
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        data_collator=default_data_collator,
        callbacks=[profiler_callback] if enable_profiler else [],
    )

    # Start training
    trainer.train()
`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...
/data/home/mreso/miniconda3/envs/llama/lib/python3.10/site-packages/bitsandbytes/autograd/_functions.py:321: UserWarning: MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization
/data/home/mreso/miniconda3/envs/llama/lib/python3.10/site-packages/bitsandbytes/autograd/_functions.py:321: UserWarning: MatMul8bitLt: inputs will be cast from torch.bfloat16 to float16 during quantization
Step | Training Loss |
---|---|
10 | 1.965000 |
20 | 1.845600 |
30 | 1.801100 |
40 | 1.780900 |
50 | 1.715400 |
60 | 1.697800 |
70 | 1.707600 |
80 | 1.713300 |
90 | 1.663900 |
100 | 1.702700 |
110 | 1.658800 |
120 | 1.692400 |
130 | 1.644900 |
140 | 1.687900 |
150 | 1.686600 |
160 | 1.649600 |
170 | 1.666900 |
180 | 1.709200 |
190 | 1.670400 |
200 | 1.662700 |
210 | 1.681300 |
220 | 1.685500 |
230 | 1.663400 |
240 | 1.638300 |
250 | 1.627400 |
260 | 1.654300 |
270 | 1.640900 |
280 | 1.674700 |
290 | 1.657300 |
300 | 1.660200 |
310 | 1.666600 |
320 | 1.674500 |
330 | 1.656200 |
340 | 1.684300 |
350 | 1.667900 |
360 | 1.661400 |
370 | 1.676800 |
380 | 1.628100 |
Save the model checkpoint:
model.save_pretrained(output_dir)
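Note that for a PEFT model, save_pretrained stores only the LoRA adapter weights and config, not the full base model. A minimal sketch of how the adapter could be reloaded later on top of a freshly loaded base model, using PeftModel from the peft package:
from peft import PeftModel

# Reload the 8-bit base model, then attach the saved LoRA adapter.
base_model = LlamaForCausalLM.from_pretrained(
    model_id, load_in_8bit=True, device_map='auto', torch_dtype=torch.float16
)
model = PeftModel.from_pretrained(base_model, output_dir)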
Try the fine-tuned model on the same example again to see the learning progress:
model.eval()
with torch.no_grad():
    print(tokenizer.decode(model.generate(**model_input, max_new_tokens=100)[0], skip_special_tokens=True))
Summarize this dialog: A: Hi Tom, are you busy tomorrow’s afternoon? B: I’m pretty sure I am. What’s up? A: Can you go with me to the animal shelter?. B: What do you want to do? A: I want to get a puppy for my son. B: That will make him so happy. A: Yeah, we’ve discussed it many times. I think he’s ready now. B: That’s good. Raising a dog is a tough issue. Like having a baby ;-) A: I'll get him one of those little dogs. B: One that won't grow up too big;-) A: And eat too much;-)) B: Do you know which one he would like? A: Oh, yes, I took him there last Monday. He showed me one that he really liked. B: I bet you had to drag him away. A: He wanted to take it home right away ;-). B: I wonder what he'll name it. A: He said he’d name it after his dead hamster – Lemmy - he's a great Motorhead fan :-))) --- Summary: A wants to get a puppy for his son. He took him to the animal shelter last Monday. He showed him one that he really liked. A will name it after his dead hamster - Lemmy.