{ "cells": [ { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "Copyright (c) Meta Platforms, Inc. and affiliates.\n", "This software may be used and distributed according to the terms of the Llama 2 Community License Agreement." ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Quick Start Notebook\n", "\n", "This notebook shows how to train a Llama 2 model on a single GPU (e.g. A10 with 24GB) using int8 quantization and LoRA.\n", "\n", "### Step 0: Install pre-requirements and convert checkpoint\n", "\n", "The example uses the Hugging Face trainer and model which means that the checkpoint has to be converted from its original format into the dedicated Hugging Face format.\n", "The conversion can be achieved by running the `convert_llama_weights_to_hf.py` script provided with the transformer package.\n", "Given that the original checkpoint resides under `models/7B` we can install all requirements and convert the checkpoint with:" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "# %%bash\n", "# pip install llama-recipes transformers datasets accelerate sentencepiece protobuf==3.20 py7zr scipy peft bitsandbytes fire torch_tb_profiler ipywidgets\n", "# TRANSFORM=`python -c \"import transformers;print('/'.join(transformers.__file__.split('/')[:-1])+'/models/llama/convert_llama_weights_to_hf.py')\"`\n", "# python ${TRANSFORM} --input_dir models --model_size 7B --output_dir models_hf/7B" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "### Step 1: Load the model\n", "\n", "Point model_id to model weight folder" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/data/home/mreso/miniconda3/envs/llama/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", " from .autonotebook import tqdm as notebook_tqdm\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "===================================BUG REPORT===================================\n", "Welcome to bitsandbytes. For bug reports, please run\n", "\n", "python -m bitsandbytes\n", "\n", " and submit this information together with your error trace to: https://github.com/TimDettmers/bitsandbytes/issues\n", "================================================================================\n", "bin /data/home/mreso/miniconda3/envs/llama/lib/python3.10/site-packages/bitsandbytes/libbitsandbytes_cuda112.so\n", "CUDA SETUP: CUDA runtime path found: /usr/local/cuda/targets/x86_64-linux/lib/libcudart.so\n", "CUDA SETUP: Highest compute capability among GPUs detected: 8.0\n", "CUDA SETUP: Detected CUDA version 112\n", "CUDA SETUP: Loading binary /data/home/mreso/miniconda3/envs/llama/lib/python3.10/site-packages/bitsandbytes/libbitsandbytes_cuda112.so...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/data/home/mreso/miniconda3/envs/llama/lib/python3.10/site-packages/bitsandbytes/cuda_setup/main.py:149: UserWarning: /data/home/mreso/miniconda3/envs/llama did not contain ['libcudart.so', 'libcudart.so.11.0', 'libcudart.so.12.0'] as expected! 
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Step 2: Load the preprocessed dataset\n",
    "\n",
    "We load and preprocess the samsum dataset, which consists of curated pairs of dialogs and their summaries:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Found cached dataset samsum (/data/home/mreso/.cache/huggingface/datasets/samsum/samsum/0.0.0/f1d7c6b7353e6de335d444e424dc002ef70d1277109031327bc9cc6af5d3d46e)\n",
      "Loading cached processed dataset at /data/home/mreso/.cache/huggingface/datasets/samsum/samsum/0.0.0/f1d7c6b7353e6de335d444e424dc002ef70d1277109031327bc9cc6af5d3d46e/cache-b14554a76c1c7ecd.arrow\n",
      "Loading cached processed dataset at /data/home/mreso/.cache/huggingface/datasets/samsum/samsum/0.0.0/f1d7c6b7353e6de335d444e424dc002ef70d1277109031327bc9cc6af5d3d46e/cache-e40e61e15ebeb527.arrow\n",
      "Loading cached processed dataset at /data/home/mreso/.cache/huggingface/datasets/samsum/samsum/0.0.0/f1d7c6b7353e6de335d444e424dc002ef70d1277109031327bc9cc6af5d3d46e/cache-e08ac9e1b792e7ba.arrow\n"
     ]
    }
   ],
   "source": [
    "from llama_recipes.utils.dataset_utils import get_preprocessed_dataset\n",
    "from llama_recipes.configs.datasets import samsum_dataset\n",
    "\n",
    "train_dataset = get_preprocessed_dataset(tokenizer, samsum_dataset, 'train')"
   ]
  },
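  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see what the model is actually trained on, it can help to decode one preprocessed sample. This is a minimal sketch (not part of the original recipe) that assumes the dataset returned above exposes an `input_ids` field:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Decode the first training sample to inspect the prompt format\n",
    "# (a dialog followed by its summary) produced by the preprocessing step.\n",
    "print(tokenizer.decode(train_dataset[0][\"input_ids\"]))"
   ]
  },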
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Step 3: Check the base model\n",
    "\n",
    "Run the base model on an example input:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Summarize this dialog:\n",
      "A: Hi Tom, are you busy tomorrow’s afternoon?\n",
      "B: I’m pretty sure I am. What’s up?\n",
      "A: Can you go with me to the animal shelter?.\n",
      "B: What do you want to do?\n",
      "A: I want to get a puppy for my son.\n",
      "B: That will make him so happy.\n",
      "A: Yeah, we’ve discussed it many times. I think he’s ready now.\n",
      "B: That’s good. Raising a dog is a tough issue. Like having a baby ;-) \n",
      "A: I'll get him one of those little dogs.\n",
      "B: One that won't grow up too big;-)\n",
      "A: And eat too much;-))\n",
      "B: Do you know which one he would like?\n",
      "A: Oh, yes, I took him there last Monday. He showed me one that he really liked.\n",
      "B: I bet you had to drag him away.\n",
      "A: He wanted to take it home right away ;-).\n",
      "B: I wonder what he'll name it.\n",
      "A: He said he’d name it after his dead hamster – Lemmy - he's a great Motorhead fan :-)))\n",
      "---\n",
      "Summary:\n",
      "A: Hi Tom, are you busy tomorrow’s afternoon?\n",
      "B: I’m pretty sure I am. What’s up?\n",
      "A: Can you go with me to the animal shelter?.\n",
      "B: What do you want to do?\n",
      "A: I want to get a puppy for my son.\n",
      "B: That will make him so happy.\n",
      "A: Yeah, we’ve discussed it many times. I think he’s ready now.\n",
      "B\n"
     ]
    }
   ],
   "source": [
    "eval_prompt = \"\"\"\n",
    "Summarize this dialog:\n",
    "A: Hi Tom, are you busy tomorrow’s afternoon?\n",
    "B: I’m pretty sure I am. What’s up?\n",
    "A: Can you go with me to the animal shelter?.\n",
    "B: What do you want to do?\n",
    "A: I want to get a puppy for my son.\n",
    "B: That will make him so happy.\n",
    "A: Yeah, we’ve discussed it many times. I think he’s ready now.\n",
    "B: That’s good. Raising a dog is a tough issue. Like having a baby ;-) \n",
    "A: I'll get him one of those little dogs.\n",
    "B: One that won't grow up too big;-)\n",
    "A: And eat too much;-))\n",
    "B: Do you know which one he would like?\n",
    "A: Oh, yes, I took him there last Monday. He showed me one that he really liked.\n",
    "B: I bet you had to drag him away.\n",
    "A: He wanted to take it home right away ;-).\n",
    "B: I wonder what he'll name it.\n",
    "A: He said he’d name it after his dead hamster – Lemmy - he's a great Motorhead fan :-)))\n",
    "---\n",
    "Summary:\n",
    "\"\"\"\n",
    "\n",
    "model_input = tokenizer(eval_prompt, return_tensors=\"pt\").to(\"cuda\")\n",
    "\n",
    "model.eval()\n",
    "with torch.no_grad():\n",
    "    print(tokenizer.decode(model.generate(**model_input, max_new_tokens=100)[0], skip_special_tokens=True))"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can see that the base model only repeats the conversation.\n",
    "\n",
    "### Step 4: Prepare the model for PEFT\n",
    "\n",
    "Let's prepare the model for Parameter-Efficient Fine-Tuning (PEFT):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "trainable params: 4194304 || all params: 6742609920 || trainable%: 0.06220594176090199\n"
     ]
    }
   ],
   "source": [
    "model.train()\n",
    "\n",
    "def create_peft_config(model):\n",
    "    from peft import (\n",
    "        get_peft_model,\n",
    "        LoraConfig,\n",
    "        TaskType,\n",
    "        prepare_model_for_kbit_training,\n",
    "    )\n",
    "\n",
    "    peft_config = LoraConfig(\n",
    "        task_type=TaskType.CAUSAL_LM,\n",
    "        inference_mode=False,\n",
    "        r=8,\n",
    "        lora_alpha=32,\n",
    "        lora_dropout=0.05,\n",
    "        target_modules=[\"q_proj\", \"v_proj\"]\n",
    "    )\n",
    "\n",
    "    # prepare int-8 model for training\n",
    "    model = prepare_model_for_kbit_training(model)\n",
    "    model = get_peft_model(model, peft_config)\n",
    "    model.print_trainable_parameters()\n",
    "    return model, peft_config\n",
    "\n",
    "# create peft config\n",
    "model, lora_config = create_peft_config(model)"
   ]
  },
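  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick optional check (a sketch added here, not part of the original recipe), we can list a few of the parameters that remain trainable. After `prepare_model_for_kbit_training` and `get_peft_model`, only the LoRA adapter tensors attached to `q_proj` and `v_proj` should still require gradients:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Collect the names of all parameters that will be updated during training;\n",
    "# everything except the LoRA adapters should be frozen at this point.\n",
    "trainable = [n for n, p in model.named_parameters() if p.requires_grad]\n",
    "print(f\"{len(trainable)} trainable tensors, for example:\")\n",
    "print(\"\\n\".join(trainable[:4]))"
   ]
  },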
" 'learning_rate': 1e-4,\n", " 'num_train_epochs': 1,\n", " 'gradient_accumulation_steps': 2,\n", " 'per_device_train_batch_size': 2,\n", " 'gradient_checkpointing': False,\n", "}\n", "\n", "# Set up profiler\n", "if enable_profiler:\n", " wait, warmup, active, repeat = 1, 1, 2, 1\n", " total_steps = (wait + warmup + active) * (1 + repeat)\n", " schedule = torch.profiler.schedule(wait=wait, warmup=warmup, active=active, repeat=repeat)\n", " profiler = torch.profiler.profile(\n", " schedule=schedule,\n", " on_trace_ready=torch.profiler.tensorboard_trace_handler(f\"{output_dir}/logs/tensorboard\"),\n", " record_shapes=True,\n", " profile_memory=True,\n", " with_stack=True)\n", " \n", " class ProfilerCallback(TrainerCallback):\n", " def __init__(self, profiler):\n", " self.profiler = profiler\n", " \n", " def on_step_end(self, *args, **kwargs):\n", " self.profiler.step()\n", "\n", " profiler_callback = ProfilerCallback(profiler)\n", "else:\n", " profiler = nullcontext()" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "### Step 6: Fine tune the model\n", "\n", "Here, we fine tune the model for a single epoch which takes a bit more than an hour on a A100." ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...\n", "/data/home/mreso/miniconda3/envs/llama/lib/python3.10/site-packages/bitsandbytes/autograd/_functions.py:321: UserWarning: MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization\n", " warnings.warn(f\"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization\")\n", "/data/home/mreso/miniconda3/envs/llama/lib/python3.10/site-packages/bitsandbytes/autograd/_functions.py:321: UserWarning: MatMul8bitLt: inputs will be cast from torch.bfloat16 to float16 during quantization\n", " warnings.warn(f\"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization\")\n" ] }, { "data": { "text/html": [ "\n", "
\n", " \n", " \n", " [389/389 1:12:06, Epoch 1/1]\n", "
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
StepTraining Loss
101.965000
201.845600
301.801100
401.780900
501.715400
601.697800
701.707600
801.713300
901.663900
1001.702700
1101.658800
1201.692400
1301.644900
1401.687900
1501.686600
1601.649600
1701.666900
1801.709200
1901.670400
2001.662700
2101.681300
2201.685500
2301.663400
2401.638300
2501.627400
2601.654300
2701.640900
2801.674700
2901.657300
3001.660200
3101.666600
3201.674500
3301.656200
3401.684300
3501.667900
3601.661400
3701.676800
3801.628100

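  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Note that `save_pretrained` on a PEFT model stores only the small LoRA adapter, not the full 7B weights. As a sketch (an addition to this notebook, not part of the original recipe), a fresh session could restore the fine-tuned model by reloading the base model and attaching the saved adapter with `PeftModel.from_pretrained`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Kept commented out to avoid loading a second copy of the model here:\n",
    "# from peft import PeftModel\n",
    "#\n",
    "# base_model = LlamaForCausalLM.from_pretrained(\n",
    "#     model_id, load_in_8bit=True, device_map='auto', torch_dtype=torch.float16\n",
    "# )\n",
    "# model = PeftModel.from_pretrained(base_model, output_dir)"
   ]
  },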
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "from transformers import default_data_collator, Trainer, TrainingArguments\n", "\n", "\n", "\n", "# Define training args\n", "training_args = TrainingArguments(\n", " output_dir=output_dir,\n", " overwrite_output_dir=True,\n", " bf16=True, # Use BF16 if available\n", " # logging strategies\n", " logging_dir=f\"{output_dir}/logs\",\n", " logging_strategy=\"steps\",\n", " logging_steps=10,\n", " save_strategy=\"no\",\n", " optim=\"adamw_torch_fused\",\n", " max_steps=total_steps if enable_profiler else -1,\n", " **{k:v for k,v in config.items() if k != 'lora_config'}\n", ")\n", "\n", "with profiler:\n", " # Create Trainer instance\n", " trainer = Trainer(\n", " model=model,\n", " args=training_args,\n", " train_dataset=train_dataset,\n", " data_collator=default_data_collator,\n", " callbacks=[profiler_callback] if enable_profiler else [],\n", " )\n", "\n", " # Start training\n", " trainer.train()" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "### Step 7:\n", "Save model checkpoint" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "model.save_pretrained(output_dir)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "### Step 8:\n", "Try the fine tuned model on the same example again to see the learning progress:" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "Summarize this dialog:\n", "A: Hi Tom, are you busy tomorrow’s afternoon?\n", "B: I’m pretty sure I am. What’s up?\n", "A: Can you go with me to the animal shelter?.\n", "B: What do you want to do?\n", "A: I want to get a puppy for my son.\n", "B: That will make him so happy.\n", "A: Yeah, we’ve discussed it many times. I think he’s ready now.\n", "B: That’s good. Raising a dog is a tough issue. Like having a baby ;-) \n", "A: I'll get him one of those little dogs.\n", "B: One that won't grow up too big;-)\n", "A: And eat too much;-))\n", "B: Do you know which one he would like?\n", "A: Oh, yes, I took him there last Monday. He showed me one that he really liked.\n", "B: I bet you had to drag him away.\n", "A: He wanted to take it home right away ;-).\n", "B: I wonder what he'll name it.\n", "A: He said he’d name it after his dead hamster – Lemmy - he's a great Motorhead fan :-)))\n", "---\n", "Summary:\n", "A wants to get a puppy for his son. He took him to the animal shelter last Monday. He showed him one that he really liked. A will name it after his dead hamster - Lemmy.\n" ] } ], "source": [ "model.eval()\n", "with torch.no_grad():\n", " print(tokenizer.decode(model.generate(**model_input, max_new_tokens=100)[0], skip_special_tokens=True))\n" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.11" }, "vscode": { "interpreter": { "hash": "2d58e898dde0263bc564c6968b04150abacfd33eed9b19aaa8e45c040360e146" } } }, "nbformat": 4, "nbformat_minor": 4 }