{ "cells": [ { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "Copyright (c) Meta Platforms, Inc. and affiliates.\n", "This software may be used and distributed according to the terms of the Llama 2 Community License Agreement." ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Quick Start Notebook\n", "\n", "This notebook shows how to fine-tune a Llama 2 model on a single GPU (e.g. an A10 with 24GB) using int8 quantization and LoRA.\n", "\n", "### Step 0: Install prerequisites and convert the checkpoint\n", "\n", "The example uses the Hugging Face trainer and model, which means that the checkpoint has to be converted from its original format into the Hugging Face format.\n", "The conversion can be done by running the `convert_llama_weights_to_hf.py` script provided with the `transformers` package.\n", "Given that the original checkpoint resides under `models/7B`, we can install all requirements and convert the checkpoint with:" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "# %%bash\n", "# pip install llama-recipes transformers datasets accelerate sentencepiece protobuf==3.20 py7zr scipy peft bitsandbytes fire torch_tb_profiler ipywidgets\n", "# TRANSFORM=`python -c \"import transformers;print('/'.join(transformers.__file__.split('/')[:-1])+'/models/llama/convert_llama_weights_to_hf.py')\"`\n", "# python ${TRANSFORM} --input_dir models --model_size 7B --output_dir models_hf/7B" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "### Step 1: Load the model\n", "\n", "Point `model_id` to the folder containing the converted model weights." ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/data/home/mreso/miniconda3/envs/llama/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. 
See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", " from .autonotebook import tqdm as notebook_tqdm\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "===================================BUG REPORT===================================\n", "Welcome to bitsandbytes. For bug reports, please run\n", "\n", "python -m bitsandbytes\n", "\n", " and submit this information together with your error trace to: https://github.com/TimDettmers/bitsandbytes/issues\n", "================================================================================\n", "bin /data/home/mreso/miniconda3/envs/llama/lib/python3.10/site-packages/bitsandbytes/libbitsandbytes_cuda112.so\n", "CUDA SETUP: CUDA runtime path found: /usr/local/cuda/targets/x86_64-linux/lib/libcudart.so\n", "CUDA SETUP: Highest compute capability among GPUs detected: 8.0\n", "CUDA SETUP: Detected CUDA version 112\n", "CUDA SETUP: Loading binary /data/home/mreso/miniconda3/envs/llama/lib/python3.10/site-packages/bitsandbytes/libbitsandbytes_cuda112.so...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/data/home/mreso/miniconda3/envs/llama/lib/python3.10/site-packages/bitsandbytes/cuda_setup/main.py:149: UserWarning: /data/home/mreso/miniconda3/envs/llama did not contain ['libcudart.so', 'libcudart.so.11.0', 'libcudart.so.12.0'] as expected! Searching further paths...\n", " warn(msg)\n", "/data/home/mreso/miniconda3/envs/llama/lib/python3.10/site-packages/bitsandbytes/cuda_setup/main.py:149: UserWarning: WARNING: The following directories listed in your path were found to be non-existent: {PosixPath('/usr/local/cuda/efa/lib')}\n", " warn(msg)\n", "The model weights are not tied. 
Please use the `tie_weights` method before using the `infer_auto_device` function.\n", "Loading checkpoint shards: 100%|██████████| 2/2 [00:10<00:00, 5.09s/it]\n" ] } ], "source": [ "import torch\n", "from transformers import LlamaForCausalLM, LlamaTokenizer\n", "\n", "model_id = \"./models_hf/7B\"\n", "\n", "tokenizer = LlamaTokenizer.from_pretrained(model_id)\n", "\n", "model = LlamaForCausalLM.from_pretrained(model_id, load_in_8bit=True, device_map='auto', torch_dtype=torch.float16)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "### Step 2: Load the preprocessed dataset\n", "\n", "We load and preprocess the samsum dataset, which consists of curated pairs of dialogs and their summaries:" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Found cached dataset samsum (/data/home/mreso/.cache/huggingface/datasets/samsum/samsum/0.0.0/f1d7c6b7353e6de335d444e424dc002ef70d1277109031327bc9cc6af5d3d46e)\n", "Loading cached processed dataset at /data/home/mreso/.cache/huggingface/datasets/samsum/samsum/0.0.0/f1d7c6b7353e6de335d444e424dc002ef70d1277109031327bc9cc6af5d3d46e/cache-b14554a76c1c7ecd.arrow\n", "Loading cached processed dataset at /data/home/mreso/.cache/huggingface/datasets/samsum/samsum/0.0.0/f1d7c6b7353e6de335d444e424dc002ef70d1277109031327bc9cc6af5d3d46e/cache-e40e61e15ebeb527.arrow\n", "Loading cached processed dataset at /data/home/mreso/.cache/huggingface/datasets/samsum/samsum/0.0.0/f1d7c6b7353e6de335d444e424dc002ef70d1277109031327bc9cc6af5d3d46e/cache-e08ac9e1b792e7ba.arrow\n" ] } ], "source": [ "from llama_recipes.utils.dataset_utils import get_preprocessed_dataset\n", "from llama_recipes.configs.datasets import samsum_dataset\n", 
"\n", "train_dataset = get_preprocessed_dataset(tokenizer, samsum_dataset, 'train')" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "### Step 3: Check base model\n", "\n", "Run the base model on an example input:" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "Summarize this dialog:\n", "A: Hi Tom, are you busy tomorrow’s afternoon?\n", "B: I’m pretty sure I am. What’s up?\n", "A: Can you go with me to the animal shelter?.\n", "B: What do you want to do?\n", "A: I want to get a puppy for my son.\n", "B: That will make him so happy.\n", "A: Yeah, we’ve discussed it many times. I think he’s ready now.\n", "B: That’s good. Raising a dog is a tough issue. Like having a baby ;-) \n", "A: I'll get him one of those little dogs.\n", "B: One that won't grow up too big;-)\n", "A: And eat too much;-))\n", "B: Do you know which one he would like?\n", "A: Oh, yes, I took him there last Monday. He showed me one that he really liked.\n", "B: I bet you had to drag him away.\n", "A: He wanted to take it home right away ;-).\n", "B: I wonder what he'll name it.\n", "A: He said he’d name it after his dead hamster – Lemmy - he's a great Motorhead fan :-)))\n", "---\n", "Summary:\n", "A: Hi Tom, are you busy tomorrow’s afternoon?\n", "B: I’m pretty sure I am. What’s up?\n", "A: Can you go with me to the animal shelter?.\n", "B: What do you want to do?\n", "A: I want to get a puppy for my son.\n", "B: That will make him so happy.\n", "A: Yeah, we’ve discussed it many times. I think he’s ready now.\n", "B\n" ] } ], "source": [ "eval_prompt = \"\"\"\n", "Summarize this dialog:\n", "A: Hi Tom, are you busy tomorrow’s afternoon?\n", "B: I’m pretty sure I am. 
What’s up?\n", "A: Can you go with me to the animal shelter?.\n", "B: What do you want to do?\n", "A: I want to get a puppy for my son.\n", "B: That will make him so happy.\n", "A: Yeah, we’ve discussed it many times. I think he’s ready now.\n", "B: That’s good. Raising a dog is a tough issue. Like having a baby ;-) \n", "A: I'll get him one of those little dogs.\n", "B: One that won't grow up too big;-)\n", "A: And eat too much;-))\n", "B: Do you know which one he would like?\n", "A: Oh, yes, I took him there last Monday. He showed me one that he really liked.\n", "B: I bet you had to drag him away.\n", "A: He wanted to take it home right away ;-).\n", "B: I wonder what he'll name it.\n", "A: He said he’d name it after his dead hamster – Lemmy - he's a great Motorhead fan :-)))\n", "---\n", "Summary:\n", "\"\"\"\n", "\n", "model_input = tokenizer(eval_prompt, return_tensors=\"pt\").to(\"cuda\")\n", "\n", "model.eval()\n", "with torch.no_grad():\n", "    print(tokenizer.decode(model.generate(**model_input, max_new_tokens=100)[0], skip_special_tokens=True))" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "We can see that the base model only repeats the conversation instead of summarizing it.\n", "\n", "### Step 4: Prepare model for PEFT\n", "\n", "Let's prepare the model for Parameter-Efficient Fine-Tuning (PEFT):" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "trainable params: 4194304 || all params: 6742609920 || trainable%: 0.06220594176090199\n" ] } ], "source": [ "model.train()\n", "\n", "def create_peft_config(model):\n", "    from peft import (\n", "        get_peft_model,\n", "        LoraConfig,\n", "        TaskType,\n", "        prepare_model_for_kbit_training,\n", "    )\n", "\n", "    peft_config = LoraConfig(\n", "        task_type=TaskType.CAUSAL_LM,\n", "        inference_mode=False,\n", "        r=8,\n", "        lora_alpha=32,\n", "        lora_dropout=0.05,\n", "        target_modules=[\"q_proj\", \"v_proj\"],\n", "    )\n", "\n", "    # prepare the int8 model for training (this also enables gradient checkpointing by default)\n", "    model = prepare_model_for_kbit_training(model)\n", "    model = get_peft_model(model, peft_config)\n", "    model.print_trainable_parameters()\n", "    return model, peft_config\n", "\n", "# wrap the model with LoRA adapters and keep the LoRA config\n", "model, lora_config = create_peft_config(model)\n", "\n" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": { "editable": true, "slideshow": { "slide_type": "" }, "tags": [] }, "source": [ "### Step 5: Define the training configuration and an optional profiler" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "from transformers import TrainerCallback\n", "from contextlib import nullcontext\n", "enable_profiler = False\n", "output_dir = \"tmp/llama-output\"\n", "\n", "config = {\n", "    'lora_config': lora_config,\n", "    'learning_rate': 1e-4,\n", "    'num_train_epochs': 1,\n", "    'gradient_accumulation_steps': 2,\n", "    'per_device_train_batch_size': 2,\n", "    'gradient_checkpointing': False,\n", "}\n", "\n", "# Set up profiler\n", "if enable_profiler:\n", "    wait, warmup, active, repeat = 1, 1, 2, 1\n", "    total_steps = (wait + warmup + active) * (1 + repeat)\n", "    schedule = torch.profiler.schedule(wait=wait, warmup=warmup, active=active, repeat=repeat)\n", "    profiler = torch.profiler.profile(\n", "        schedule=schedule,\n", "        on_trace_ready=torch.profiler.tensorboard_trace_handler(f\"{output_dir}/logs/tensorboard\"),\n", "        record_shapes=True,\n", "        profile_memory=True,\n", "        with_stack=True)\n", "\n", "    class ProfilerCallback(TrainerCallback):\n", "        def __init__(self, profiler):\n", "            self.profiler = profiler\n", "\n", "        def on_step_end(self, *args, **kwargs):\n", "            self.profiler.step()\n", "\n", "    profiler_callback = ProfilerCallback(profiler)\n", "else:\n", "    profiler = nullcontext()" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "### Step 6: Fine-tune the model\n", "\n", "Here, we fine-tune the model for a single epoch, which takes a bit more than an hour on an A100."
] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...\n", "/data/home/mreso/miniconda3/envs/llama/lib/python3.10/site-packages/bitsandbytes/autograd/_functions.py:321: UserWarning: MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization\n", " warnings.warn(f\"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization\")\n", "/data/home/mreso/miniconda3/envs/llama/lib/python3.10/site-packages/bitsandbytes/autograd/_functions.py:321: UserWarning: MatMul8bitLt: inputs will be cast from torch.bfloat16 to float16 during quantization\n", " warnings.warn(f\"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization\")\n" ] }, { "data": { "text/html": [ "\n", "
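<table>\n", "<thead>\n", "<tr><th>Step</th><th>Training Loss</th></tr>\n", "</thead>\n", "<tbody>\n", "
<tr><td>10</td><td>1.965000</td></tr>\n", "
<tr><td>20</td><td>1.845600</td></tr>\n", "
<tr><td>30</td><td>1.801100</td></tr>\n", "
<tr><td>40</td><td>1.780900</td></tr>\n", "
<tr><td>50</td><td>1.715400</td></tr>\n", "
<tr><td>60</td><td>1.697800</td></tr>\n", "
<tr><td>70</td><td>1.707600</td></tr>\n", "
<tr><td>80</td><td>1.713300</td></tr>\n", "
<tr><td>90</td><td>1.663900</td></tr>\n", "
<tr><td>100</td><td>1.702700</td></tr>\n", "
<tr><td>110</td><td>1.658800</td></tr>\n", "
<tr><td>120</td><td>1.692400</td></tr>\n", "
<tr><td>130</td><td>1.644900</td></tr>\n", "
<tr><td>140</td><td>1.687900</td></tr>\n", "
<tr><td>150</td><td>1.686600</td></tr>\n", "
<tr><td>160</td><td>1.649600</td></tr>\n", "
<tr><td>170</td><td>1.666900</td></tr>\n", "
<tr><td>180</td><td>1.709200</td></tr>\n", "
<tr><td>190</td><td>1.670400</td></tr>\n", "
<tr><td>200</td><td>1.662700</td></tr>\n", "
<tr><td>210</td><td>1.681300</td></tr>\n", "
<tr><td>220</td><td>1.685500</td></tr>\n", "
<tr><td>230</td><td>1.663400</td></tr>\n", "
<tr><td>240</td><td>1.638300</td></tr>\n", "
<tr><td>250</td><td>1.627400</td></tr>\n", "
<tr><td>260</td><td>1.654300</td></tr>\n", "
<tr><td>270</td><td>1.640900</td></tr>\n", "
<tr><td>280</td><td>1.674700</td></tr>\n", "
<tr><td>290</td><td>1.657300</td></tr>\n", "
<tr><td>300</td><td>1.660200</td></tr>\n", "
<tr><td>310</td><td>1.666600</td></tr>\n", "
<tr><td>320</td><td>1.674500</td></tr>\n", "
<tr><td>330</td><td>1.656200</td></tr>\n", "
<tr><td>340</td><td>1.684300</td></tr>\n", "
<tr><td>350</td><td>1.667900</td></tr>\n", "
<tr><td>360</td><td>1.661400</td></tr>\n", "
<tr><td>370</td><td>1.676800</td></tr>\n", "
<tr><td>380</td><td>1.628100</td></tr>\n", "
</tbody>\n", "</table>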
"
],
"text/plain": [
"