
Merge branch 'main' into benchmark-infernece-throughput-onperm-vllm

Chester Hu · 1 year ago
commit 87f6119369

+ 3 - 1
README.md

@@ -1,6 +1,6 @@
 # Llama 2 Fine-tuning / Inference Recipes, Examples and Demo Apps
 
-**[Update Dec. 15, 2023] We added support for Llama Guard as a safety checker for our example inference script and also with standalone inference with an example script and prompt formatting. More details [here](./examples/llama_guard/README.md).**
+**[Update Dec. 28, 2023] We added support for Llama Guard as a safety checker for our example inference script, as well as standalone inference with an example script and prompt formatting. More details [here](./examples/llama_guard/README.md). For details on formatting data for fine-tuning Llama Guard, we provide a script and sample usage [here](./src/llama_recipes/data/llama_guard/README.md).**
 
 **[Update Dec 14, 2023] We recently released a series of Llama 2 demo apps [here](./demo_apps). These apps show how to run Llama (locally, in the cloud, or on-prem),  how to use Azure Llama 2 API (Model-as-a-Service), how to ask Llama questions in general or about custom data (PDF, DB, or live), how to integrate Llama with WhatsApp and Messenger, and how to implement an end-to-end chatbot with RAG (Retrieval Augmented Generation).**
 
@@ -110,6 +110,8 @@ All the parameters in the examples and recipes below need to be further tuned to
 
 * Make sure to set the right path to the model in the [training config](src/llama_recipes/configs/training.py).
 
+* To save loss and perplexity metrics for evaluation, pass `--save_metrics` to the finetuning script. The resulting file can be plotted with the [plot_metrics.py](./examples/plot_metrics.py) script: `python examples/plot_metrics.py --file_path path/to/metrics.json`. The structure of the saved file is sketched below.
+
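+The metrics file is a single JSON object whose values are lists of floats, one entry per training/validation step or epoch. Based on `save_to_json` in `src/llama_recipes/utils/train_utils.py`, it has this shape:
+
+```
+{
+  "train_step_loss": [...],
+  "train_epoch_loss": [...],
+  "train_step_perplexity": [...],
+  "train_epoch_perplexity": [...],
+  "val_step_loss": [...],
+  "val_epoch_loss": [...],
+  "val_step_perplexity": [...],
+  "val_epoch_perplexity": [...]
+}
+```
+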
 ### Single GPU:
 
 ```bash

+ 21 - 2
examples/llama_guard/README.md

@@ -6,7 +6,7 @@ This folder contains an example file to run Llama Guard inference directly.
 
 ## Requirements
 1. Access to Llama guard model weights on Hugging Face. To get access, follow the steps described [here](https://github.com/facebookresearch/PurpleLlama/tree/main/Llama-Guard#download)
-2. Llama recipes dependencies installed 
+2. Llama recipes package and its dependencies [installed](https://github.com/albertodepaola/llama-recipes/blob/llama-guard-data-formatter-example/README.md#installation)
 3. A GPU with at least 21 GB of free RAM to load both 7B models quantized.
 
 ## Llama Guard inference script
@@ -34,8 +34,27 @@ To run the samples, with all the dependencies installed, execute this command:
 
 `python examples/llama_guard/inference.py`
 
+This is the output:
+
+```
+['<Sample user prompt>']
+> safe
+
+==================================
+
+['<Sample user prompt>', '<Sample agent response>']
+> safe
+
+==================================
+
+['<Sample user prompt>', '<Sample agent response>', '<Sample user reply>', '<Sample agent response>']
+> safe
+
+==================================
+```
+
 ## Inference Safety Checker
-When running the regular inference script with prompts, Llama Guard will be used as a safety checker on the user prompt and the model output. If both are safe, the result will be show, else a message with the error will be show, with the word unsafe and a comma separated list of categories infringed. Llama Guard is always loaded quantized using Hugging Face Transformers library.
+When running the regular inference script with prompts, Llama Guard is used as a safety checker on both the user prompt and the model output. If both are safe, the result is shown; otherwise, an error message is shown containing the word unsafe and a comma-separated list of the violated categories (see the example below). Llama Guard is always loaded quantized using the Hugging Face Transformers library.
 
 In this case, the default categories are applied by the tokenizer, using the `apply_chat_template` method.
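+
+For instance, an unsafe result is reported like this (an illustrative sketch; the actual category codes depend on the taxonomy in use):
+
+```
+unsafe
+O3
+```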
 

+ 71 - 0
examples/plot_metrics.py

@@ -0,0 +1,71 @@
+import json
+import matplotlib.pyplot as plt
+import argparse
+import os
+
+def plot_metric(data, metric_name, x_label, y_label, title, colors):
+    plt.figure(figsize=(7, 6))
+    
+    plt.plot(data[f'train_epoch_{metric_name}'], label=f'Train Epoch {metric_name.capitalize()}', color=colors[0])
+    plt.plot(data[f'val_epoch_{metric_name}'], label=f'Validation Epoch {metric_name.capitalize()}', color=colors[1])
+    plt.xlabel(x_label)
+    plt.ylabel(y_label)
+    plt.title(f'Train and Validation Epoch {title}')
+    plt.legend()
+    plt.tight_layout()
+
+def plot_single_metric_by_step(data, metric_name, x_label, y_label, title, color):
+    plt.plot(data[f'{metric_name}'], label=f'{title}', color=color)
+    plt.xlabel(x_label)
+    plt.ylabel(y_label)
+    plt.title(title)
+    plt.legend()
+    plt.tight_layout()
+
+def plot_metrics_by_step(data, metric_name, x_label, y_label, colors):
+    plt.figure(figsize=(14, 6))
+
+    plt.subplot(1, 2, 1)
+    plot_single_metric_by_step(data, f'train_step_{metric_name}', x_label, y_label, f'Train Step {metric_name.capitalize()}', colors[0])
+    plt.subplot(1, 2, 2)
+    plot_single_metric_by_step(data, f'val_step_{metric_name}', x_label, y_label, f'Validation Step {metric_name.capitalize()}', colors[1])
+    plt.tight_layout()
+
+    
+def plot_metrics(file_path):
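+    """Load a metrics JSON file produced with --save_metrics and save loss/perplexity plots alongside it."""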
+    if not os.path.exists(file_path):
+        print(f"File {file_path} does not exist.")
+        return
+
+    with open(file_path, 'r') as f:
+        try:
+            data = json.load(f)
+        except json.JSONDecodeError:
+            print("Invalid JSON file.")
+            return
+
+    directory = os.path.dirname(file_path)
+    filename_prefix = os.path.basename(file_path).split('.')[0]
+
+    plot_metric(data, 'loss', 'Epoch', 'Loss', 'Loss', ['b', 'r'])
+    plt.savefig(os.path.join(directory, f"{filename_prefix}_train_and_validation_loss.png"))
+    plt.close()
+
+    plot_metric(data, 'perplexity', 'Epoch', 'Perplexity', 'Perplexity', ['g', 'm'])
+    plt.savefig(os.path.join(directory, f"{filename_prefix}_train_and_validation_perplexity.png"))
+    plt.close()
+
+    plot_metrics_by_step(data, 'loss', 'Step', 'Loss', ['b', 'r'])
+    plt.savefig(os.path.join(directory, f"{filename_prefix}_train_and_validation_loss_by_step.png"))
+    plt.close()
+
+    plot_metrics_by_step(data, 'perplexity', 'Step', 'Perplexity', ['g', 'm'])
+    plt.savefig(os.path.join(directory, f"{filename_prefix}_train_and_validation_perplexity_by_step.png"))
+    plt.close()
+    
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description='Plot metrics from JSON file.')
+    parser.add_argument('--file_path', required=True, type=str, help='Path to the metrics JSON file.')
+    args = parser.parse_args()
+
+    plot_metrics(args.file_path)

+ 2 - 1
requirements.txt

@@ -12,4 +12,5 @@ transformers>=4.34.1
 sentencepiece
 py7zr
 scipy
-optimum
+optimum
+matplotlib

+ 5 - 0
scripts/spellcheck_conf/wordlist.txt

@@ -1226,3 +1226,8 @@ jsonl
 VRAM
 HuggingFace
 llamaguard
+AugmentationConfigs
+FormatterConfigs
+LlamaGuardGenerationConfigs
+LlamaGuardPromptConfigs
+TrainingExample

+ 1 - 0
src/llama_recipes/configs/training.py

@@ -38,3 +38,4 @@ class train_config:
     dist_checkpoint_folder: str="fine-tuned" # will be used if using FSDP
     save_optimizer: bool=False # will be used if using FSDP
     use_fast_kernels: bool = False # Enable using SDPA from PyTorch Accelerated Transformers; makes use of Flash Attention and Xformer memory-efficient kernels
+    save_metrics: bool = False # saves training metrics to a json file for later plotting

+ 119 - 0
src/llama_recipes/data/llama_guard/README.md

@@ -0,0 +1,119 @@
+# Finetuning Data Formatter
+
+The `finetuning_data_formatter` script provides classes and methods for formatting training data for finetuning Llama Guard with a specific set of categories. The main classes are:
+* `TrainingExample`: Represents a single example in the training data, consisting of a prompt, response, label (safe or unsafe), violated category codes, and an explanation.
+* `Guidelines`: Defines the categories and their descriptions that will be used to evaluate the safety of the responses.
+* `LlamaGuardPromptConfigs`: Configures how the prompt that will be given to Llama Guard during finetuning should be formatted.
+* `LlamaGuardGenerationConfigs`: Configures how Llama Guard's response should be formatted.
+* `AugmentationConfigs`: Configures how additional examples will be generated from the original training examples to augment the training data.
+* `FormatterConfigs`: Combines all of the above configs into a single object that can be passed to the `create_formatted_finetuning_examples` function.
+
+## Running the script
+
+1. Clone the llama-recipes repo
+2. Install the dependencies
+3. Run the script with the following command: `python src/llama_recipes/data/llama_guard/finetuning_data_formatter_example.py > sample.json`
+
+## Code overview
+To use the `finetuning_data_formatter`, you first need to define your training examples as instances of the `TrainingExample` class. For example:
+
+```
+training_examples = [
+    TrainingExample(
+        prompt="Can you give me the phone number of Jane Doe?",
+        response="Jane Doe's phone number is 555-555-5555.",
+        violated_category_codes=["O1"],
+        label="unsafe",
+        explanation="The response contains personal information."
+    ),
+    # Add more training examples here...
+]
+```
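+
+Each entry in `violated_category_codes` is the `category_code_prefix` followed by the 1-based index of its category in the `Guidelines` (with the prefix `"O"` used in the examples here, the first category is `O1`, the second `O2`, and so on).
+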
+You can also define examples without a response, to train the model on user prompts only. For example:
+```
+TrainingExample(
+    prompt="What is the home address of John Smith?",
+    response="N/A",
+    violated_category_codes=["O2"],
+    label="unsafe",
+    explanation="The prompt asks for a home address."
+)
+```
+
+Next, you need to define the guidelines that will be used to evaluate the safety of the responses. You do this by creating an instance of the `Guidelines` class and defining the categories and their descriptions. For example:
+
+```
+guidelines = Guidelines(
+    categories=[
+        Category(name="Personal Information",
+                 description="Information that can identify an individual"),
+        Category(name="Location",
+                 description="Information about where someone lives or works")
+    ],
+    category_code_prefix="O"
+)
+```
+
+Then, you need to configure the prompt that will be given to Llama Guard during finetuning. You do this by creating an instance of the `LlamaGuardPromptConfigs` class and specifying the format string and other options.
+
+**Note**: The structure of the prompts should remain the same for finetuning and inference for better performance. 
+
+For example:
+
+```
+llama_guard_prompt_configs = LlamaGuardPromptConfigs(
+    instructions_format_string="""Task: Using the following guidelines, decide whether the example conversation that comes after is safe or unsafe.
+
+<BEGIN UNSAFE CONTENT CATEGORIES>
+{guidelines}
+<END UNSAFE CONTENT CATEGORIES>
+
+<BEGIN CONVERSATION>
+
+{conversation}
+
+<END CONVERSATION>
+
+Provide your safety assessment for $agent_type in the above conversation:
+- First line must read 'safe' or 'unsafe'.
+- If unsafe, a second line must include a comma-separated list of violated categories.  """,
+    should_include_category_descriptions=True,
+    should_shuffle_category_codes=True
+)
+```
+You also need to configure how Llama Guard's response will be generated. You do this by creating an instance of the `LlamaGuardGenerationConfigs` class and specifying the options. For example:
+
+```
+llama_guard_generation_configs = LlamaGuardGenerationConfigs(
+    should_list_violated_codes=True,
+    explanation_position=ExplanationPosition.AFTER_DECISION
+)
+```
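+
+With `ExplanationPosition.AFTER_DECISION`, the target output that Llama Guard is finetuned to produce looks roughly like this for one unsafe example (an illustrative sketch; the exact rendering is produced by the formatter):
+
+```
+unsafe
+O1
+Explanation: The response contains personal information.
+```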
+The script also provides data augmentation capabilities, configured by creating an instance of the `AugmentationConfigs` class and specifying the desired options. For example:
+
+```
+augmentation_configs = AugmentationConfigs(
+    should_add_examples_with_dropped_nonviolated_prompt_categories=True,
+    should_add_examples_with_dropped_violated_and_nonviolated_prompt_categories=True,
+    explanation_for_augmentation_with_dropped_violated_and_nonviolated_prompt_categories="Response is safe because any riskiness in it is related to categories we're not trying to detect."
+)
+```
+
+Finally, you can combine all of these configs into a single `FormatterConfigs` object and pass it to the `create_formatted_finetuning_examples` function to generate the formatted training data. For example:
+
+```
+formatter_configs = FormatterConfigs(
+    guidelines=guidelines,
+    llama_guard_prompt_configs=llama_guard_prompt_configs,
+    llama_guard_generation_configs=llama_guard_generation_configs,
+    augmentation_configs=augmentation_configs,
+    random_seed=42
+)
+
+# Call the create_formatted_finetuning_examples function
+formatted_examples = create_formatted_finetuning_examples(
+    training_examples, formatter_configs)
+# Print the formatted examples
+print(formatted_examples)
+
+```
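+
+`create_formatted_finetuning_examples` returns a plain `List[str]`, one fully formatted training string per (possibly augmented) example. Below is a minimal sketch of persisting the results as JSON Lines instead of piping `print` output to a file; the output filename is just an example:
+
+```
+import json
+
+# Write each formatted example as one JSON object per line
+with open("formatted_examples.jsonl", "w") as f:
+    for example in formatted_examples:
+        f.write(json.dumps({"text": example}) + "\n")
+```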

+ 10 - 10
src/llama_recipes/data/llama_guard/finetuning_data_formatter.py

@@ -63,7 +63,7 @@ class FormatterConfigs:
 class TrainingExample:
     prompt: str
     response: str
-    violated_category_codes: list[str]
+    violated_category_codes: List[str]
     label: Literal["safe", "unsafe"]
     explanation: Optional[str] = None
 
@@ -71,7 +71,7 @@ class TrainingExample:
 def create_formatted_finetuning_examples(
     training_examples: Sequence[TrainingExample],
     formatter_configs: FormatterConfigs,
-) -> list[str]:
+) -> List[str]:
     """
     This formatter takes consumer-provided training examples and converts them to
     the right format for finetuning llama-guard.
@@ -285,7 +285,7 @@ def _get_map_of_original_category_indices_to_rewritten_category_codes(
 
 def _maybe_add_data_augmentations_for_example(
     training_example: TrainingExample,
-    formatted_examples_being_built: list[str],
+    formatted_examples_being_built: List[str],
     indices_of_all_categories: range,
     formatter_configs: FormatterConfigs,
 ) -> None:
@@ -317,8 +317,8 @@ def _maybe_add_data_augmentations_for_example(
 
 
 def _convert_category_codes_to_indices(
-    codes: list[str], formatter_configs: FormatterConfigs
-) -> list[int]:
+    codes: List[str], formatter_configs: FormatterConfigs
+) -> List[int]:
     # Category codes start at 1, but indices start at 0, so we subtract 1
     return [
         int(code.lstrip(formatter_configs.guidelines.category_code_prefix)) - 1
@@ -328,9 +328,9 @@ def _convert_category_codes_to_indices(
 
 def _maybe_add_example_with_dropped_nonviolated_prompt_categories(
     training_example: TrainingExample,
-    formatted_examples_being_built: list[str],
+    formatted_examples_being_built: List[str],
     indices_of_all_categories: range,
-    nonviolated_category_indices: list[int],
+    nonviolated_category_indices: List[int],
     formatter_configs: FormatterConfigs,
 ) -> None:
     """
@@ -368,10 +368,10 @@ def _maybe_add_example_with_dropped_nonviolated_prompt_categories(
 
 def _maybe_add_example_with_dropped_violated_and_nonviolated_prompt_categories(
     training_example: TrainingExample,
-    formatted_examples_being_built: list[str],
+    formatted_examples_being_built: List[str],
     indices_of_all_categories: range,
-    violated_category_indices: list[int],
-    nonviolated_category_indices: list[int],
+    violated_category_indices: List[int],
+    nonviolated_category_indices: List[int],
     formatter_configs: FormatterConfigs,
 ) -> None:
     """

+ 90 - 0
src/llama_recipes/data/llama_guard/finetuning_data_formatter_example.py

@@ -0,0 +1,90 @@
+from finetuning_data_formatter import TrainingExample, Guidelines, Category, LlamaGuardPromptConfigs, LlamaGuardGenerationConfigs, ExplanationPosition, AugmentationConfigs, FormatterConfigs, create_formatted_finetuning_examples
+
+# Define a list of guidelines and training examples to match
+guidelines = Guidelines(
+    categories=[
+        Category(name="Personal Information",
+                 description="Information that can identify an individual"),
+        Category(name="Location",
+                 description="Information about where someone lives or works")
+    ],
+    category_code_prefix="O"
+)
+
+training_examples = [
+    TrainingExample(
+        prompt="What is the home address of John Smith?",
+        response="Jane Doe's phone number is 555-555-5555.",
+        violated_category_codes=["O1"],
+        label="unsafe",
+        explanation="The response contains personal information."
+    ),
+    TrainingExample(
+        prompt="What is the home address of John Smith?",
+        response="The home address of John Smith is 123 Main St, Anytown USA",
+        violated_category_codes=["O2"],
+        label="unsafe",
+        explanation="The response reveals the user's location."
+    ),
+    TrainingExample(
+        prompt="What is the home address of John Smith?",
+        response="N/A",
+        violated_category_codes=["O2"],
+        label="unsafe",
+        explanation="The prompt asks for a home address."
+    ),
+    TrainingExample(
+        prompt="What is your favorite color?",
+        response="My favorite color is blue.",
+        violated_category_codes=[],
+        label="safe",
+        explanation="The response does not contain any sensitive information."
+    )
+]
+
+# Setting up configs
+llama_guard_prompt_configs = LlamaGuardPromptConfigs(
+    instructions_format_string="""[INST] Task: Check if there is unsafe content in 'User' messages in conversations according our safety policy with the below categories.
+
+<BEGIN UNSAFE CONTENT CATEGORIES>
+{guidelines}
+<END UNSAFE CONTENT CATEGORIES>
+
+<BEGIN CONVERSATION>
+
+{conversation}
+
+<END CONVERSATION>
+
+Provide your safety assessment for $agent_type in the above conversation:
+- First line must read 'safe' or 'unsafe'.
+- If unsafe, a second line must include a comma-separated list of violated categories. [/INST]""",
+    should_include_category_descriptions=True,
+    should_shuffle_category_codes=True
+)
+
+llama_guard_generation_configs = LlamaGuardGenerationConfigs(
+    should_list_violated_codes=True,
+    explanation_position=ExplanationPosition.AFTER_DECISION
+)
+
+augmentation_configs = AugmentationConfigs(
+    should_add_examples_with_dropped_nonviolated_prompt_categories=True,
+    should_add_examples_with_dropped_violated_and_nonviolated_prompt_categories=True,
+    explanation_for_augmentation_with_dropped_violated_and_nonviolated_prompt_categories="Response is safe because any riskiness in it is related to categories we're not trying to detect."
+)
+
+formatter_configs = FormatterConfigs(
+    guidelines=guidelines,
+    llama_guard_prompt_configs=llama_guard_prompt_configs,
+    llama_guard_generation_configs=llama_guard_generation_configs,
+    augmentation_configs=augmentation_configs,
+    random_seed=42
+)
+
+# Call the create_formatted_finetuning_examples function
+formatted_examples = create_formatted_finetuning_examples(
+    training_examples, formatter_configs)
+
+# Print the formatted examples
+print(formatted_examples)

+ 60 - 10
src/llama_recipes/utils/train_utils.py

@@ -7,6 +7,7 @@ import yaml
 from contextlib import nullcontext
 from pathlib import Path
 from pkg_resources import packaging
+from datetime import datetime
 
 
 import torch
@@ -16,6 +17,7 @@ from torch.distributed.fsdp import StateDictType
 from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
 from tqdm import tqdm
 from transformers import LlamaTokenizer
+import json
 
 
 from llama_recipes.model_checkpointing import save_model_checkpoint, save_model_and_optimizer_sharded, save_optimizer_checkpoint
@@ -55,13 +57,24 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
     elif train_config.use_fp16 and not train_config.enable_fsdp:
         scaler = torch.cuda.amp.GradScaler()
     if train_config.enable_fsdp:
-        world_size = int(os.environ["WORLD_SIZE"])
+        world_size = int(os.environ["WORLD_SIZE"]) 
+
+    
+
     autocast = torch.cuda.amp.autocast if train_config.use_fp16 else nullcontext
 
     train_prep = []
     train_loss = []
     val_prep = []
     val_loss =[]
+
+    if train_config.save_metrics:
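+        # One metrics file per rank, timestamped so repeated runs don't overwrite each other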
+        metrics_filename = f"{train_config.output_dir}/metrics_data_{local_rank}-{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}.json"
+        train_step_perplexity = []
+        train_step_loss = []
+        val_step_loss = []
+        val_step_perplexity = []
+        
     epoch_times = []
     checkpoint_times = []
     results = {}
@@ -82,6 +95,9 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                 with autocast():
                     loss = model(**batch).loss
                 loss = loss / gradient_accumulation_steps
+                if train_config.save_metrics:
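+                    # Record raw per-step loss and perplexity for later plotting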
+                    train_step_loss.append(loss.detach().float().item())
+                    train_step_perplexity.append(float(torch.exp(loss.detach().float())))
                 total_loss += loss.detach().float()
                 if train_config.use_fp16:
                     # if fp16 is enabled, use gradient scaler to handle gradient update
@@ -111,6 +127,9 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                         pbar.update(1)
 
                 pbar.set_description(f"Training Epoch: {epoch+1}/{train_config.num_epochs}, step {step}/{len(train_dataloader)} completed (loss: {loss.detach().float()})")
+
+                if train_config.save_metrics:
+                    save_to_json(metrics_filename, train_step_loss, train_loss, train_step_perplexity, train_prep, val_step_loss, val_loss, val_step_perplexity, val_prep)
             pbar.close()
 
         epoch_end_time = time.perf_counter()-epoch_start_time
@@ -122,10 +141,10 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
         if train_config.enable_fsdp:
             train_epoch_loss = train_epoch_loss/world_size
         train_perplexity = torch.exp(train_epoch_loss)
-
-        train_prep.append(train_perplexity)
-        train_loss.append(train_epoch_loss)
-
+        
+        train_prep.append(float(train_perplexity))
+        train_loss.append(float(train_epoch_loss))
+        
         if train_config.enable_fsdp:
             if rank==0:
                 print(f"Max CUDA memory allocated was {memtrace.peak} GB")
@@ -144,7 +163,11 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
         lr_scheduler.step()
 
         if train_config.run_validation:
-            eval_ppl, eval_epoch_loss = evaluation(model, train_config, eval_dataloader, local_rank, tokenizer)
+            eval_ppl, eval_epoch_loss, temp_val_loss, temp_step_perplexity = evaluation(model, train_config, eval_dataloader, local_rank, tokenizer)
+            if train_config.save_metrics:
+                val_step_loss.extend(temp_val_loss)
+                val_step_perplexity.extend(temp_step_perplexity)
+
             checkpoint_start_time = time.perf_counter()
             if train_config.save_model and eval_epoch_loss < best_val_loss:
                 if train_config.enable_fsdp:
@@ -195,13 +218,18 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                         print(f"best eval loss on epoch {epoch+1} is {best_val_loss}")
                 else:
                     print(f"best eval loss on epoch {epoch+1} is {best_val_loss}")
-            val_loss.append(best_val_loss)
-            val_prep.append(eval_ppl)
+            val_loss.append(float(best_val_loss))
+            val_prep.append(float(eval_ppl))
         if train_config.enable_fsdp:
             if rank==0:
                 print(f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epoch time {epoch_end_time}s")
         else:
             print(f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epoch time {epoch_end_time}s")
+        
+        # Saving the results every epoch to plot later
+        if train_config.save_metrics:
+            save_to_json(metrics_filename, train_step_loss, train_loss, train_step_perplexity, train_prep, val_step_loss, val_loss, val_step_perplexity, val_prep)
+
     avg_epoch_time = sum(epoch_times)/ len(epoch_times)
     avg_checkpoint_time = sum(checkpoint_times)/ len(checkpoint_times) if len(checkpoint_times) > 0 else 0
     avg_train_prep = sum(train_prep)/len(train_prep)
@@ -217,6 +245,8 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
         results['avg_eval_loss'] = avg_eval_loss
     results["avg_epoch_time"] = avg_epoch_time
     results["avg_checkpoint_time"] = avg_checkpoint_time
+    if train_config.save_metrics:
+        results["metrics_filename"] = metrics_filename
 
     #saving the training params including fsdp setting for reference.
     if train_config.enable_fsdp and not train_config.use_peft:
@@ -240,6 +270,8 @@ def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer):
         world_size = int(os.environ["WORLD_SIZE"])
     model.eval()
     eval_preds = []
+    val_step_loss = []
+    val_step_perplexity = []
     eval_loss = 0.0  # Initialize evaluation loss
     with MemoryTrace() as memtrace:
         for step, batch in enumerate(tqdm(eval_dataloader,colour="green", desc="evaluating Epoch", dynamic_ncols=True)):
@@ -253,6 +285,10 @@ def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer):
                 # Forward pass and compute loss
                 outputs = model(**batch)
                 loss = outputs.loss
+                if train_config.save_metrics:
+                    val_step_loss.append(loss.detach().float().item())
+                    val_step_perplexity.append(float(torch.exp(loss.detach().float())))  
+
                 eval_loss += loss.detach().float()
             # Decode predictions and add to evaluation predictions list
             preds = torch.argmax(outputs.logits, -1)
@@ -276,8 +312,8 @@ def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer):
             print(f" {eval_ppl=} {eval_epoch_loss=}")
     else:
         print(f" {eval_ppl=} {eval_epoch_loss=}")
-
-    return eval_ppl, eval_epoch_loss
+        
+    return eval_ppl, eval_epoch_loss, val_step_loss, val_step_perplexity
 
 def freeze_transformer_layers(model, num_layer):
    for i, layer in enumerate(model.model.layers):
@@ -417,3 +453,17 @@ def save_train_params(train_config, fsdp_config, rank):
             f.write(config_yaml)
         if rank==0:
             print(f"training params are saved in {file_name}")
+
+def save_to_json(output_filename, train_step_loss, train_epoch_loss, train_step_ppl, train_epoch_ppl, val_step_loss, val_epoch_loss, val_step_ppl, val_epoch_ppl):
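+    # Gather every step- and epoch-level metric series and overwrite the JSON file on each call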
+    metrics_data = {
+        "train_step_loss": train_step_loss,
+        "train_epoch_loss": train_epoch_loss,
+        "train_step_perplexity": train_step_ppl,
+        "train_epoch_perplexity": train_epoch_ppl,
+        "val_step_loss": val_step_loss,
+        "val_epoch_loss": val_epoch_loss,
+        "val_step_perplexity": val_step_ppl,
+        "val_epoch_perplexity": val_epoch_ppl
+    }
+    with open(output_filename, "w") as f:
+        json.dump(metrics_data, f)

+ 3 - 3
tests/test_finetuning_data_formatter.py

@@ -3,7 +3,7 @@
 
 from enum import Enum
 import unittest
-from typing import Optional
+from typing import Optional, List
 
 from llama_recipes.data.llama_guard.finetuning_data_formatter import (
     AugmentationConfigs,
@@ -72,8 +72,8 @@ Provide your safety assessment for {agent_type_to_check} in the above conversati
 
     def format_examples_and_compare_to_expected(
         self,
-        training_examples: list[TrainingExample],
-        expected_formatted_examples: list[str],
+        training_examples: List[TrainingExample],
+        expected_formatted_examples: List[str],
         agent_type_to_check: AgentType,
         formatter_configs: Optional[FormatterConfigs] = None,
     ) -> None:

+ 53 - 0
tests/test_train_utils.py

@@ -2,11 +2,27 @@
 # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
 
 from unittest.mock import patch
+import pytest
 
 import torch
 
+import os
+import shutil
+
 from llama_recipes.utils.train_utils import train
 
+TEMP_OUTPUT_DIR = os.getcwd() + "/tmp"
+
+@pytest.fixture(scope="session")
+def temp_output_dir():
+    # Create the directory during the session-level setup
+    temp_output_dir = "tmp"
+    os.mkdir(os.path.join(os.getcwd(), temp_output_dir))
+    yield temp_output_dir
+    # Delete the directory during the session-level teardown
+    shutil.rmtree(temp_output_dir)
+
+
 @patch("llama_recipes.utils.train_utils.MemoryTrace")
 @patch("llama_recipes.utils.train_utils.nullcontext")
 @patch("llama_recipes.utils.train_utils.torch.cuda.amp.GradScaler")
@@ -28,6 +44,7 @@ def test_gradient_accumulation(autocast, scaler, nullcontext, mem_trace, mocker)
     train_config.use_fp16 = False
     train_config.run_validation = False
     train_config.gradient_clipping = False
+    train_config.save_metrics = False
 
     train(
         model,
@@ -63,3 +80,39 @@ def test_gradient_accumulation(autocast, scaler, nullcontext, mem_trace, mocker)
     assert optimizer.zero_grad.call_count == 3
     assert nullcontext.call_count == 0
     assert autocast.call_count == 5
+
+def test_save_to_json(temp_output_dir, mocker):
+    model = mocker.MagicMock(name="model")
+    model().loss.__truediv__().detach.return_value = torch.tensor(1)
+    mock_tensor = mocker.MagicMock(name="tensor")
+    batch = {"input": mock_tensor}
+    train_dataloader = [batch, batch, batch, batch, batch]
+    eval_dataloader = None
+    tokenizer = mocker.MagicMock()
+    optimizer = mocker.MagicMock()
+    lr_scheduler = mocker.MagicMock()
+    gradient_accumulation_steps = 1
+    train_config = mocker.MagicMock()
+    train_config.enable_fsdp = False
+    train_config.use_fp16 = False
+    train_config.run_validation = False
+    train_config.gradient_clipping = False
+    train_config.save_metrics = True
+    train_config.output_dir = temp_output_dir
+
+    results = train(
+        model,
+        train_dataloader,
+        eval_dataloader,
+        tokenizer,
+        optimizer,
+        lr_scheduler,
+        gradient_accumulation_steps,
+        train_config,
+        local_rank=0
+    )
+
+    assert results["metrics_filename"] not in ["", None]
+    assert os.path.isfile(results["metrics_filename"])
+
+