
fine-tune school math

Alex Skidanov 6 months ago
parent
commit efd1db3754

+ 19 - 10
src/llama_recipes/configs/datasets.py

@@ -3,32 +3,41 @@
 
 from dataclasses import dataclass
 
-    
+
 @dataclass
 class samsum_dataset:
-    dataset: str =  "samsum_dataset"
+    dataset: str = "samsum_dataset"
+    train_split: str = "train"
+    test_split: str = "validation"
+
+
+@dataclass
+class school_math_dataset:
+    dataset: str = "school_math_dataset"
     train_split: str = "train"
     test_split: str = "validation"
-    
-    
+
+
 @dataclass
 class grammar_dataset:
     dataset: str = "grammar_dataset"
-    train_split: str = "src/llama_recipes/datasets/grammar_dataset/gtrain_10k.csv" 
-    test_split: str = "src/llama_recipes/datasets/grammar_dataset/grammar_validation.csv"
+    train_split: str = "src/llama_recipes/datasets/grammar_dataset/gtrain_10k.csv"
+    test_split: str = (
+        "src/llama_recipes/datasets/grammar_dataset/grammar_validation.csv"
+    )
+
 
-    
 @dataclass
 class alpaca_dataset:
     dataset: str = "alpaca_dataset"
     train_split: str = "train"
     test_split: str = "val"
     data_path: str = "src/llama_recipes/datasets/alpaca_data.json"
-    
-    
+
+
 @dataclass
 class custom_dataset:
     dataset: str = "custom_dataset"
     file: str = "examples/custom_dataset.py"
     train_split: str = "train"
-    test_split: str = "validation"
+    test_split: str = "validation"
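
The new school_math_dataset config mirrors samsum_dataset: only the name differs, and that name is what generate_dataset_config later matches against train_config.dataset. A minimal sketch (not part of the commit) of instantiating the new dataclass, assuming the module path shown above:

    # Not part of this commit: instantiate the new config dataclass directly.
    from dataclasses import asdict
    from llama_recipes.configs.datasets import school_math_dataset

    config = school_math_dataset()
    print(asdict(config))
    # {'dataset': 'school_math_dataset', 'train_split': 'train', 'test_split': 'validation'}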

+ 1 - 1
src/llama_recipes/configs/training.py

@@ -28,7 +28,7 @@ class train_config:
     use_fp16: bool=False
     mixed_precision: bool=True
     val_batch_size: int=1
-    dataset = "samsum_dataset"
+    dataset = "school_math_dataset"
     peft_method: str = "lora" # None , llama_adapter, prefix
     use_peft: bool=False
     output_dir: str = "PATH/to/save/PEFT/model"
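
With the default dataset switched to school_math_dataset, no extra flags are needed to fine-tune on it, and the old behaviour stays reachable because finetuning.py passes CLI kwargs through update_config. A hedged sketch of overriding the default programmatically (actually running it still needs the model, tokenizer, and GPU setup):

    # Sketch only: equivalent to passing --dataset samsum_dataset on the fire CLI.
    from llama_recipes.finetuning import main

    main(dataset="samsum_dataset", use_peft=True, peft_method="lora")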

+ 12 - 3
src/llama_recipes/datasets/__init__.py

@@ -1,6 +1,15 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
 
-from llama_recipes.datasets.grammar_dataset.grammar_dataset import get_dataset as get_grammar_dataset
-from llama_recipes.datasets.alpaca_dataset import InstructionDataset as get_alpaca_dataset
-from llama_recipes.datasets.samsum_dataset import get_preprocessed_samsum as get_samsum_dataset
+from llama_recipes.datasets.grammar_dataset.grammar_dataset import (
+    get_dataset as get_grammar_dataset,
+)
+from llama_recipes.datasets.alpaca_dataset import (
+    InstructionDataset as get_alpaca_dataset,
+)
+from llama_recipes.datasets.samsum_dataset import (
+    get_preprocessed_samsum as get_samsum_dataset,
+)
+from llama_recipes.datasets.school_math_dataset import (
+    get_preprocessed_school_math as get_school_math_dataset,
+)

+ 9 - 7
src/llama_recipes/datasets/samsum_dataset.py

@@ -10,9 +10,7 @@ import datasets
 def get_preprocessed_samsum(dataset_config, tokenizer, split):
     dataset = datasets.load_dataset("samsum", split=split)
 
-    prompt = (
-        f"Summarize this dialog:\n{{dialog}}\n---\nSummary:\n"
-    )
+    prompt = f"Summarize this dialog:\n{{dialog}}\n---\nSummary:\n"
 
     def apply_prompt_template(sample):
         return {
@@ -23,14 +21,18 @@ def get_preprocessed_samsum(dataset_config, tokenizer, split):
     dataset = dataset.map(apply_prompt_template, remove_columns=list(dataset.features))
 
     def tokenize_add_label(sample):
-        prompt = tokenizer.encode(tokenizer.bos_token + sample["prompt"], add_special_tokens=False)
-        summary = tokenizer.encode(sample["summary"] +  tokenizer.eos_token, add_special_tokens=False)
+        prompt = tokenizer.encode(
+            tokenizer.bos_token + sample["prompt"], add_special_tokens=False
+        )
+        summary = tokenizer.encode(
+            sample["summary"] + tokenizer.eos_token, add_special_tokens=False
+        )
 
         sample = {
             "input_ids": prompt + summary,
-            "attention_mask" : [1] * (len(prompt) + len(summary)),
+            "attention_mask": [1] * (len(prompt) + len(summary)),
             "labels": [-100] * len(prompt) + summary,
-            }
+        }
 
         return sample
 

+ 66 - 0
src/llama_recipes/datasets/school_math_dataset.py

@@ -0,0 +1,66 @@
+import json
+import random
+
+import datasets
+
+
+def get_preprocessed_school_math(dataset_config, tokenizer, split):
+    with open("/home/setup/.datasets/jasnah-school-math/precise/grades4to6.json") as f:
+        data = json.load(f)
+
+    # Shuffle data
+    rng = random.Random(42)
+    order = list(range(len(data)))
+    rng.shuffle(order)
+    data = [data[i] for i in order]
+
+    train_percent = 70
+    train_count = int(len(data) * train_percent / 100)
+
+    if split == "train":
+        data = data[:train_count]
+    else:
+        data = data[train_count:]
+
+    data = {
+        "problem": [d["problem"] for d in data],
+        "solution": [d["solution"] for d in data],
+        "answer": [d["answer"] for d in data],
+    }
+
+    dataset = datasets.Dataset.from_dict(data)
+
+    prompt_template = f"Solve the following math problem based on the provided description. The description provided is in Russian and uses LaTeX in some places.\nYou should include a detailed step-by-step explanation of your solution. The final response must clearly present the precise answer.\n<problem>\n{{problem}}\n</problem>\n"
+    response_template = (
+        f"<solution>{{solution}}</solution>\n<answer>{{answer}}</answer>\n"
+    )
+
+    def apply_prompt_template(sample):
+        return {
+            "prompt": prompt_template.format(problem=sample["problem"]),
+            "response": response_template.format(
+                solution=sample["solution"], answer=sample["answer"]
+            ),
+        }
+
+    dataset = dataset.map(apply_prompt_template, remove_columns=list(dataset.features))
+
+    def tokenize_add_label(sample):
+        prompt = tokenizer.encode(
+            tokenizer.bos_token + sample["prompt"], add_special_tokens=False
+        )
+        response = tokenizer.encode(
+            sample["response"] + tokenizer.eos_token, add_special_tokens=False
+        )
+
+        sample = {
+            "input_ids": prompt + response,
+            "attention_mask": [1] * (len(prompt) + len(response)),
+            "labels": [-100] * len(prompt) + response,
+        }
+
+        return sample
+
+    dataset = dataset.map(tokenize_add_label, remove_columns=list(dataset.features))
+
+    return dataset
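
The loader reads a JSON list of records with problem, solution, and answer fields from a hard-coded path, shuffles them with a fixed seed, and takes the first 70% for training and the rest for validation. A sketch of the expected record shape and a direct call (the record text is made up, the tokenizer is assumed to be loaded elsewhere, and the JSON file must exist at the path above):

    # Sketch: expected record shape and a direct call to the new loader.
    example_record = {
        "problem": "У Маши было 3 яблока, ей дали ещё 2. Сколько яблок стало?",  # Russian, may contain LaTeX
        "solution": "3 + 2 = 5",
        "answer": "5",
    }

    train_ds = get_preprocessed_school_math(None, tokenizer, "train")       # first 70% of records
    val_ds = get_preprocessed_school_math(None, tokenizer, "validation")    # remaining 30%
    # dataset_config is unused by this loader, so None is acceptable here.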

+ 79 - 26
src/llama_recipes/finetuning.py

@@ -9,10 +9,7 @@ import random
 import torch
 import torch.optim as optim
 from peft import get_peft_model, prepare_model_for_kbit_training
-from torch.distributed.fsdp import (
-    FullyShardedDataParallel as FSDP,
-    ShardingStrategy
-)
+from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy
 
 from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload
 from torch.optim.lr_scheduler import StepLR
@@ -49,6 +46,7 @@ from llama_recipes.utils.train_utils import (
 )
 from accelerate.utils import is_xpu_available
 
+
 def setup_wandb(train_config, fsdp_config, **kwargs):
     try:
         import wandb
@@ -58,6 +56,7 @@ def setup_wandb(train_config, fsdp_config, **kwargs):
             "Please install it using pip install wandb"
         )
     from llama_recipes.configs import wandb_config as WANDB_CONFIG
+
     wandb_config = WANDB_CONFIG()
     update_config(wandb_config, **kwargs)
     init_dict = dataclasses.asdict(wandb_config)
@@ -67,10 +66,34 @@ def setup_wandb(train_config, fsdp_config, **kwargs):
     return run
 
 
+def display(config):
+    from io import StringIO
+    from re import compile
+    from termcolor import colored
+
+    buffer = StringIO()
+    print(config, file=buffer)
+    text = buffer.getvalue()
+
+    pat = compile("^(\w+)\(([\w\d]+=[^,]+(, [\w\d]+=[^,]+)*)\)$")
+
+    result = pat.match(text)
+    assert result is not None
+    name = result.group(1)
+    print()
+    print(colored(name.replace("_", " ").upper(), "blue"))
+    for key, value in map(lambda s: s.split("="), result.group(2).split(", ")):
+        print(colored(key, "green"), "=", colored(value, "red"))
+
+
 def main(**kwargs):
     # Update the configuration for the training and sharding process
     train_config, fsdp_config = TRAIN_CONFIG(), FSDP_CONFIG()
     update_config((train_config, fsdp_config), **kwargs)
+
+    display(train_config)
+    display(fsdp_config)
+
     # Set the seeds for reproducibility
     if is_xpu_available():
         torch.xpu.manual_seed(train_config.seed)
@@ -95,7 +118,7 @@ def main(**kwargs):
     wandb_run = None
 
     if train_config.use_wandb:
-        if not train_config.enable_fsdp or rank==0:
+        if not train_config.enable_fsdp or rank == 0:
             wandb_run = setup_wandb(train_config, fsdp_config, **kwargs)
 
     # Load the pre-trained model and setup its configuration
@@ -131,13 +154,19 @@ def main(**kwargs):
         )
 
     # Load the tokenizer and add special tokens
-    tokenizer = AutoTokenizer.from_pretrained(train_config.model_name if train_config.tokenizer_name is None else train_config.tokenizer_name)
+    tokenizer = AutoTokenizer.from_pretrained(
+        train_config.model_name
+        if train_config.tokenizer_name is None
+        else train_config.tokenizer_name
+    )
     tokenizer.pad_token_id = tokenizer.eos_token_id
 
-    # If there is a mismatch between tokenizer vocab size and embedding matrix, 
+    # If there is a mismatch between tokenizer vocab size and embedding matrix,
     # throw a warning and then expand the embedding matrix
     if len(tokenizer) > model.get_input_embeddings().weight.shape[0]:
-        print("WARNING: Resizing the embedding matrix to match the tokenizer vocab size.")
+        print(
+            "WARNING: Resizing the embedding matrix to match the tokenizer vocab size."
+        )
         model.resize_token_embeddings(len(tokenizer))
 
     print_model_size(model, train_config, rank if train_config.enable_fsdp else 0)
@@ -157,13 +186,18 @@ def main(**kwargs):
         if wandb_run:
             wandb_run.config.update(peft_config)
 
-
     hsdp_device_mesh = None
-    if fsdp_config.hsdp and fsdp_config.sharding_strategy == ShardingStrategy.HYBRID_SHARD:
-        hsdp_device_mesh = hsdp_device_mesh(replica_group_size=fsdp_config.replica_group_size, sharding_group_size=fsdp_config.sharding_group_size)
+    if (
+        fsdp_config.hsdp
+        and fsdp_config.sharding_strategy == ShardingStrategy.HYBRID_SHARD
+    ):
+        hsdp_device_mesh = hsdp_device_mesh(
+            replica_group_size=fsdp_config.replica_group_size,
+            sharding_group_size=fsdp_config.sharding_group_size,
+        )
         print("HSDP device mesh is ready")
 
-    #setting up FSDP if enable_fsdp is enabled
+    # setting up FSDP if enable_fsdp is enabled
     if train_config.enable_fsdp:
         if not train_config.use_peft and train_config.freeze_layers:
 
@@ -180,16 +214,27 @@ def main(**kwargs):
 
         model = FSDP(
             model,
-            auto_wrap_policy= my_auto_wrapping_policy if train_config.use_peft else wrapping_policy,
-            cpu_offload=CPUOffload(offload_params=True) if fsdp_config.fsdp_cpu_offload else None,
-            mixed_precision=mixed_precision_policy if not fsdp_config.pure_bf16 else None,
+            auto_wrap_policy=(
+                my_auto_wrapping_policy if train_config.use_peft else wrapping_policy
+            ),
+            cpu_offload=(
+                CPUOffload(offload_params=True)
+                if fsdp_config.fsdp_cpu_offload
+                else None
+            ),
+            mixed_precision=(
+                mixed_precision_policy if not fsdp_config.pure_bf16 else None
+            ),
             sharding_strategy=fsdp_config.sharding_strategy,
             device_mesh=hsdp_device_mesh,
             device_id=device_id,
             limit_all_gathers=True,
             sync_module_states=train_config.low_cpu_fsdp,
-            param_init_fn=lambda module: module.to_empty(device=torch.device("cuda"), recurse=False)
-            if train_config.low_cpu_fsdp and rank != 0 else None,
+            param_init_fn=lambda module: (
+                module.to_empty(device=torch.device("cuda"), recurse=False)
+                if train_config.low_cpu_fsdp and rank != 0
+                else None
+            ),
         )
         if fsdp_config.fsdp_activation_checkpointing:
             apply_fsdp_checkpointing(model)
@@ -201,7 +246,6 @@ def main(**kwargs):
 
     dataset_config = generate_dataset_config(train_config, kwargs)
 
-     # Load and preprocess the dataset for training and validation
     dataset_train = get_preprocessed_dataset(
         tokenizer,
         dataset_config,
@@ -217,12 +261,16 @@ def main(**kwargs):
         split="test",
     )
     if not train_config.enable_fsdp or rank == 0:
-            print(f"--> Validation Set Length = {len(dataset_val)}")
+        print(f"--> Validation Set Length = {len(dataset_val)}")
 
     if train_config.batching_strategy == "packing":
-        dataset_train = ConcatDataset(dataset_train, chunk_size=train_config.context_length)
+        dataset_train = ConcatDataset(
+            dataset_train, chunk_size=train_config.context_length
+        )
 
-    train_dl_kwargs = get_dataloader_kwargs(train_config, dataset_train, tokenizer, "train")
+    train_dl_kwargs = get_dataloader_kwargs(
+        train_config, dataset_train, tokenizer, "train"
+    )
 
     # Create DataLoaders for the training and validation dataset
     train_dataloader = torch.utils.data.DataLoader(
@@ -235,9 +283,13 @@ def main(**kwargs):
     eval_dataloader = None
     if train_config.run_validation:
         if train_config.batching_strategy == "packing":
-            dataset_val = ConcatDataset(dataset_val, chunk_size=train_config.context_length)
+            dataset_val = ConcatDataset(
+                dataset_val, chunk_size=train_config.context_length
+            )
 
-        val_dl_kwargs = get_dataloader_kwargs(train_config, dataset_val, tokenizer, "val")
+        val_dl_kwargs = get_dataloader_kwargs(
+            train_config, dataset_val, tokenizer, "val"
+        )
 
         eval_dataloader = torch.utils.data.DataLoader(
             dataset_val,
@@ -279,11 +331,12 @@ def main(**kwargs):
         rank if train_config.enable_fsdp else None,
         wandb_run,
     )
-    if not train_config.enable_fsdp or rank==0:
-        [print(f'Key: {k}, Value: {v}') for k, v in results.items()]
+    if not train_config.enable_fsdp or rank == 0:
+        [print(f"Key: {k}, Value: {v}") for k, v in results.items()]
         if train_config.use_wandb:
-            for k,v in results.items():
+            for k, v in results.items():
                 wandb_run.summary[k] = v
 
+
 if __name__ == "__main__":
     fire.Fire(main)
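
The new display helper parses the dataclass repr with a regex and prints each field via termcolor, so the effective config is visible at startup on every run. Roughly what it prints for the dataset config above (colors omitted; importing finetuning assumes the full training environment is installed):

    from llama_recipes.finetuning import display
    from llama_recipes.configs.datasets import school_math_dataset

    display(school_math_dataset())
    # SCHOOL MATH DATASET
    # dataset = 'school_math_dataset'
    # train_split = 'train'
    # test_split = 'validation'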

+ 8 - 2
src/llama_recipes/utils/dataset_utils.py

@@ -11,6 +11,7 @@ from llama_recipes.datasets import (
     get_grammar_dataset,
     get_alpaca_dataset,
     get_samsum_dataset,
+    get_school_math_dataset,
 )
 
 
@@ -39,13 +40,17 @@ def get_custom_dataset(dataset_config, tokenizer, split: str):
 
     module_path = Path(module_path)
     if not module_path.is_file():
-        raise FileNotFoundError(f"Dataset py file {module_path.as_posix()} does not exist or is not a file.")
+        raise FileNotFoundError(
+            f"Dataset py file {module_path.as_posix()} does not exist or is not a file."
+        )
 
     module = load_module_from_py_file(module_path.as_posix())
     try:
         return getattr(module, func_name)(dataset_config, tokenizer, split)
     except AttributeError as e:
-        print(f"It seems like the given method name ({func_name}) is not present in the dataset .py file ({module_path.as_posix()}).")
+        print(
+            f"It seems like the given method name ({func_name}) is not present in the dataset .py file ({module_path.as_posix()})."
+        )
         raise e
 
 
@@ -54,6 +59,7 @@ DATASET_PREPROC = {
     "grammar_dataset": get_grammar_dataset,
     "samsum_dataset": get_samsum_dataset,
     "custom_dataset": get_custom_dataset,
+    "school_math_dataset": get_school_math_dataset,
 }
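
Registering the loader in DATASET_PREPROC is the final wiring step: get_preprocessed_dataset looks the configured name up in this dict and calls the registered function with (dataset_config, tokenizer, split). A sketch of that dispatch, with dataset_config and tokenizer assumed to be in scope:

    # Sketch of the dispatch performed by get_preprocessed_dataset (names assumed in scope).
    loader = DATASET_PREPROC["school_math_dataset"]          # -> get_school_math_dataset
    train_ds = loader(dataset_config, tokenizer, dataset_config.train_split)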