Forráskód Böngészése

merge main branch

Hamid Shojanazeri 1 éve
szülő
commit
5b916114eb

+ 79 - 0
.github/ISSUE_TEMPLATE/bug.yml

@@ -0,0 +1,79 @@
+name: 🐛 Bug Report
+description: Create a report to help us reproduce and fix the bug
+
+body:
+  - type: markdown
+    attributes:
+      value: >
+        #### Before submitting a bug, please make sure the issue hasn't been already addressed by searching through [the
+        existing and past issues](https://github.com/facebookresearch/llama-recipes/issues), the [FAQ](https://github.com/facebookresearch/llama-recipes/blob/main/docs/FAQ.md) 
+
+  - type: textarea
+    id: system-info
+    attributes:
+      label: System Info
+      description: |
+        Please share your system info with us. You can use the following command to capture your environment information
+        python -m "torch.utils.collect_env"
+
+      placeholder: | 
+        PyTorch version, CUDA version, GPU type, #num of GPUs...   
+    validations:
+      required: true
+
+  - type: checkboxes
+    id: information-scripts-examples
+    attributes:
+      label: Information
+      description: 'The problem arises when using:'
+      options:
+        - label: "The official example scripts"
+        - label: "My own modified scripts"
+
+  - type: textarea
+    id: bug-description
+    attributes:
+      label: 🐛 Describe the bug
+      description: |
+        Please provide a clear and concise description of what the bug is.
+
+        Provide the exact command(s) that you ran with the settings eg using FSDP and PEFT or pure FSDP.
+        
+        Please also paste or describe the results you observe instead of the expected results. 
+      placeholder: |
+        A clear and concise description of what the bug is.
+        
+        ```python
+        # Command that you used for running the examples
+        ```
+        Description of the results
+    validations:
+      required: true
+
+  - type: textarea
+    attributes:
+      label: Error logs
+      description: |
+       If you observe an error, please paste the error message including the **full** traceback of the exception. It may be relevant to wrap error messages in ```` ```triple quotes blocks``` ````.
+
+      placeholder: |
+        ```
+        The error message you got, with the full traceback.
+        ```
+
+    validations:
+      required: true
+
+  
+  - type: textarea
+    id: expected-behavior
+    validations:
+      required: true
+    attributes:
+      label: Expected behavior
+      description: "A clear and concise description of what you would expect to happen."
+
+  - type: markdown
+    attributes:
+      value: >
+        Thanks for contributing 🎉!

+ 31 - 0
.github/ISSUE_TEMPLATE/feature-request.yml

@@ -0,0 +1,31 @@
+name: 🚀 Feature request
+description: Submit a proposal/request for a new llama-recipes feature
+
+body:
+- type: textarea
+  id: feature-pitch
+  attributes:
+    label: 🚀 The feature, motivation and pitch
+    description: >
+      A clear and concise description of the feature proposal. Please outline the motivation for the proposal. Is your feature request related to a specific problem? e.g., *"I'm working on X and would like Y to be possible"*. If this is related to another GitHub issue, please link here too.
+  validations:
+    required: true
+
+- type: textarea
+  id: alternatives
+  attributes:
+    label: Alternatives
+    description: >
+      A description of any alternative solutions or features you've considered, if any.
+
+- type: textarea
+  id: additional-context
+  attributes:
+    label: Additional context
+    description: >
+      Add any other context or screenshots about the feature request.
+
+- type: markdown
+  attributes:
+    value: >
+      Thanks for contributing 🎉!

+ 38 - 0
.github/PULL_REQUEST_TEMPLATE.md

@@ -0,0 +1,38 @@
+# What does this PR do?
+
+<!--
+Congratulations! You've made it this far! You're not quite done yet though.
+
+Please include a good title that fully reflects the extent of your awesome contribution.
+
+Then, please replace this with a description of the change and which issue is fixed (if applicable). Please also include relevant motivation and context. List any dependencies (if any) that are required for this change.
+
+-->
+
+<!-- Remove if not applicable -->
+
+Fixes # (issue)
+
+
+## Feature/Issue validation/testing
+
+Please describe the tests that you ran to verify your changes and relevant result summary. Provide instructions so it can be reproduced.
+Please also list any relevant details for your test configuration.
+
+- [ ] Test A
+Logs for Test A
+
+- [ ] Test B
+Logs for Test B
+
+
+## Before submitting
+- [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
+- [ ] Did you read the [contributor guideline](https://github.com/facebookresearch/llama-recipes/blob/main/CONTRIBUTING.md#pull-requests),
+      Pull Request section?
+- [ ] Was this discussed/approved via a Github issue? Please add a link
+      to it if that's the case.
+- [ ] Did you make sure to update the documentation with your changes?  
+- [ ] Did you write any new necessary tests?
+
+Thanks for contributing 🎉!

+ 66 - 0
.github/workflows/spellcheck.yml

@@ -0,0 +1,66 @@
+name: SpellCheck
+
+on:
+  push:
+    branches:
+      - main
+  pull_request:
+    branches:
+      - main
+jobs:
+  build:
+    runs-on: ubuntu-20.04
+    name: Lint changed files
+    steps:
+      - uses: actions/checkout@v3
+        with:
+          fetch-depth: 0  # OR "2" -> To retrieve the preceding commit.
+
+      - name: Check links in all markdown files
+        uses: gaurav-nelson/github-action-markdown-link-check@1.0.13
+        with:
+          use-verbose-mode: 'yes'
+          config-file: "scripts/markdown_link_check_config.json"
+
+      - name: Get changed files
+        id: changed-files
+        uses: tj-actions/changed-files@v29.0.4
+        with:
+
+          files: |
+            **/*.py
+
+  spellcheck:
+    runs-on: ubuntu-20.04
+    steps:
+      - uses: actions/checkout@v3
+
+      - name: Install dependencies
+        run: |
+          sudo apt-get install aspell aspell-en
+          pip install pyspelling
+
+      - name: Get changed files
+        id: changed-files
+        uses: tj-actions/changed-files@v29.0.4
+        with:
+          files: |
+            **/*.md
+
+      - name: Check spellings
+        run: |
+          sources=""
+          for file in ${{ steps.changed-files.outputs.all_changed_files }}; do
+            sources="${sources} -S $file"
+          done
+          if [ ! "$sources" ]; then
+            echo "No files to spellcheck"
+          else
+            pyspelling -c $GITHUB_WORKSPACE/scripts/spellcheck_conf/spellcheck.yaml --name Markdown $sources
+          fi
+
+      - name: In the case of misspellings
+        if: ${{ failure() }}
+        run: |
+          echo "Please fix the misspellings. If you are sure about some of them, "
+          echo "so append those to scripts/spellcheck_conf/wordlist.txt"

+ 20 - 17
README.md

@@ -7,12 +7,12 @@ Llama 2 is a new technology that carries potential risks with use. Testing condu
 
 # Table of Contents
 1. [Quick start](#quick-start)
-2. [Fine-tuning](#fine-tuning)
+2. [Model Conversion](#model-conversion-to-hugging-face)
+3. [Fine-tuning](#fine-tuning)
     - [Single GPU](#single-gpu)
     - [Multi GPU One Node](#multiple-gpus-one-node)
     - [Multi GPU Multi Node](#multi-gpu-multi-node)
-3. [Inference](./docs/inference.md)
-4. [Model Conversion](#model-conversion-to-hugging-face)
+4. [Inference](./docs/inference.md)
 5. [Repository Organization](#repository-organization)
 6. [License and Acceptable Use Policy](#license)
 
@@ -46,6 +46,23 @@ pip install -r requirements.txt
 
 **Please note that the above requirements.txt will install PyTorch 2.0.1 version, in case you want to run FSDP + PEFT, please make sure to install PyTorch nightlies.**
 
+# Model conversion to Hugging Face
+The recipes and notebooks in this folder are using the Llama 2 model definition provided by Hugging Face's transformers library.
+
+Given that the original checkpoint resides under models/7B you can install all requirements and convert the checkpoint with:
+
+```bash
+## Install HuggingFace Transformers from source
+pip freeze | grep transformers ## verify it is version 4.31.0 or higher
+
+```bash
+git clone git@github.com:huggingface/transformers.git
+cd transformers
+pip install protobuf
+python src/transformers/models/llama/convert_llama_weights_to_hf.py \
+   --input_dir /path/to/downloaded/llama/weights --model_size 7B --output_dir /output/path
+```
+
 # Fine-tuning
 
 For fine-tuning Llama 2 models for your domain-specific use cases recipes for PEFT, FSDP, PEFT+FSDP have been included along with a few test datasets. For details see [LLM Fine-tuning](./docs/LLM_finetuning.md).
@@ -112,20 +129,6 @@ sbatch multi_node.slurm
 You can read more about our fine-tuning strategies [here](./docs/LLM_finetuning.md).
 
 
-# Model conversion to Hugging Face
-The recipes and notebooks in this folder are using the Llama 2 model definition provided by Hugging Face's transformers library.
-
-Given that the original checkpoint resides under models/7B you can install all requirements and convert the checkpoint with:
-
-```bash
-## Install HuggingFace Transformers from source
-pip install git+https://github.com/huggingface/transformers
-cd transformers
-
-python src/transformers/models/llama/convert_llama_weights_to_hf.py \
-    --input_dir /path/to/downloaded/llama/weights --model_size 7B --output_dir models_hf/7B
-```
-
 # Repository Organization
 This repository is organized in the following way:
 

+ 1 - 1
configs/fsdp.py

@@ -13,7 +13,7 @@ class fsdp_config:
     sharding_strategy: ShardingStrategy = ShardingStrategy.FULL_SHARD
     checkpoint_type: StateDictType = StateDictType.SHARDED_STATE_DICT  # alternatively can use SHARDED_STATE_DICT save one file per rank, and can resize the world-size.
     fsdp_activation_checkpointing: bool=True
-    pure_bf16: bool = True
+    pure_bf16: bool = False
     optimizer: str= "AdamW"
     
     

+ 2 - 2
docs/Dataset.md

@@ -10,7 +10,7 @@ The provided fine tuning script allows you to select between three datasets by p
 
 The list of available datasets can easily be extended with custom datasets by following these instructions.
 
-Each dataset has a corresponding configuration (dataclass) in [configs/dataset.py](../configs/dataset.py) which contains the dataset name, training/validation split names, as well as optional parameters like datafiles etc.
+Each dataset has a corresponding configuration (dataclass) in [configs/datasets.py](../configs/datasets.py) which contains the dataset name, training/validation split names, as well as optional parameters like datafiles etc.
 
 Additionally, there is a preprocessing function for each dataset in the [ft_datasets](../ft_datasets) folder.
 The returned data of the dataset needs to be consumable by the forward method of the fine-tuned model by calling ```model(**data)```.
@@ -18,7 +18,7 @@ For CausalLM models this usually means that the data needs to be in the form of
 
 To add a custom dataset the following steps need to be performed.
 
-1. Create a dataset configuration after the schema described above. Examples can be found in [configs/dataset.py](../configs/dataset.py).
+1. Create a dataset configuration after the schema described above. Examples can be found in [configs/datasets.py](../configs/datasets.py).
 2. Create a preprocessing routine which loads the data and returns a PyTorch style dataset. The signature for the preprocessing function needs to be (dataset_config, tokenizer, split_name) where split_name will be the string for train/validation split as defined in the dataclass.
 3. Register the dataset name and preprocessing function by inserting it as key and value into the DATASET_PREPROC dictionary in [utils/dataset_utils.py](../utils/dataset_utils.py)
 4. Set dataset field in training config to dataset name or use --dataset option of the llama_finetuning.py training script.

+ 1 - 1
docs/inference.md

@@ -31,7 +31,7 @@ inference/samsum_prompt.txt
 The inference folder also includes a chat completion example, that adds built-in safety features in fine-tuned models to the prompt tokens. To run the example:
 
 ```bash
-python chat_completion.py --model_name "PATH/TO/MODEL/7B/" --prompt_file chats.json  --quantization --use_auditnlg
+python inference/chat_completion.py --model_name "PATH/TO/MODEL/7B/" --prompt_file inference/chats.json  --quantization --use_auditnlg
 
 ```
 ## Loading back FSDP checkpoints

+ 1 - 1
llama_finetuning.py

@@ -134,7 +134,7 @@ def main(**kwargs):
             mixed_precision=mixed_precision_policy if not fsdp_config.pure_bf16 else None,
             sharding_strategy=fsdp_config.sharding_strategy,
             device_id=torch.cuda.current_device(),
-            limit_all_gathers=False,
+            limit_all_gathers=True,
         )
         if fsdp_config.fsdp_activation_checkpointing:
             policies.apply_fsdp_checkpointing(model)

+ 1 - 1
model_checkpointing/checkpoint_handler.py

@@ -212,7 +212,7 @@ def save_optimizer_checkpoint(model, optimizer, rank, cfg, epoch=1):
 
 
 def load_optimizer_checkpoint(model, optimizer, rank, cfg):
-    """load an fdsp optimizer full_state checkpoint using scatter method
+    """load an fsdp optimizer full_state checkpoint using scatter method
     this ensures only rank 0 loads the optimizer state dict and scatters to other ranks
     """
 

+ 1 - 1
policies/activation_checkpointing_functions.py

@@ -26,7 +26,7 @@ def apply_fsdp_checkpointing(model):
     """apply activation checkpointing to model
     returns None as model is updated directly
     """
-    print(f"--> applying fdsp activation checkpointing...")
+    print(f"--> applying fsdp activation checkpointing...")
 
     apply_activation_checkpointing(
         model, checkpoint_wrapper_fn=non_reentrant_wrapper, check_fn=check_fn

+ 27 - 0
scripts/markdown_link_check_config.json

@@ -0,0 +1,27 @@
+{
+  "retryOn429": true,
+  "retryCount": 5,
+  "fallbackRetryDelay": "10s",
+  "httpHeaders": [
+    {
+      "urls": [
+        "https://docs.github.com/",
+        "https://help.github.com/"
+      ],
+      "headers": {
+        "Accept-Encoding": "zstd, br, gzip, deflate"
+      }
+    }
+  ],
+  "ignorePatterns": [
+    {
+      "pattern": "^http(s)?://127.0.0.1.*"
+    },
+    {
+      "pattern": "^http(s)?://localhost.*"
+    },
+    {
+      "pattern": "https://www.intel.com/content/www/us/en/developer/articles/news/llama2.html"
+    }
+  ]
+}

+ 20 - 0
scripts/spellcheck.sh

@@ -0,0 +1,20 @@
+# Source: https://github.com/pytorch/torchx/blob/main/scripts/spellcheck.sh
+set -ex
+sudo apt-get install aspell
+
+if [[ -z "$@" ]]; then
+    sources=$(find -name '*.md')
+else
+    sources=$@
+fi
+
+sources_arg=""
+for src in $sources; do
+        sources_arg="${sources_arg} -S $src"
+done
+
+if [ ! "$sources_arg" ]; then
+	echo "No files to spellcheck"
+else
+	pyspelling -c scripts/spellcheck_conf/spellcheck.yaml --name Markdown $sources_arg
+fi

+ 22 - 0
scripts/spellcheck_conf/spellcheck.yaml

@@ -0,0 +1,22 @@
+matrix:
+- name: Markdown
+  apsell:
+    lang: en
+    d: en_US
+  dictionary:
+    wordlists:
+    - scripts/spellcheck_conf/wordlist.txt
+    output: scripts/spellcheck_conf/wordlist.dic
+    encoding: utf-8
+  pipeline:
+  - pyspelling.filters.context:
+      context_visible_first: true
+      delimiters:
+      - open: '(?s)^ *(?P<open>`{3,})[a-z0-9]*?$'
+        close: '^(?P=open)$'
+      - open: ''
+        content: 'https?://[-a-zA-Z0-9.]+?\.[a-z]{2,6}[-?=&%.0-9a-zA-Z/_#]*'
+        close: ''
+  - pyspelling.filters.markdown:
+      markdown_extensions:
+      - markdown.extensions.extra:

A különbségek nem kerülnek megjelenítésre, a fájl túl nagy
+ 1070 - 0
scripts/spellcheck_conf/wordlist.txt


+ 1 - 0
utils/memory_utils.py

@@ -50,6 +50,7 @@ class MemoryTrace:
         self.end = byte2gb(torch.cuda.memory_allocated())
         self.peak = byte2gb(torch.cuda.max_memory_allocated())
         cuda_info = torch.cuda.memory_stats()
+        self.peak_active_gb = byte2gb(cuda_info["active_bytes.all.peak"])
         self.cuda_malloc_retires = cuda_info.get("num_alloc_retries", 0)
         self.m_cuda_ooms = cuda_info.get("num_ooms", 0)
         self.used = byte2gb(self.end - self.begin)

+ 5 - 5
utils/train_utils.py

@@ -84,9 +84,8 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                     if train_config.enable_fsdp:
                         batch[key] = batch[key].to(local_rank)
                     else:
-                        batch[key] = batch[key].to('cuda')       
-                outputs = model(**batch)
-                loss = outputs.loss
+                        batch[key] = batch[key].to('cuda:0')              
+                loss = model(**batch).loss
                 loss = loss / gradient_accumulation_steps
                 total_loss += loss.detach().float()
                 first_key = next(iter(batch))
@@ -105,7 +104,7 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                         optimizer.step()
                         optimizer.zero_grad()
                         
-                print(f"\n step {step} is completed and loss is {loss.detach().float()}")        
+                print(f"\n step {step} is completed and loss is {loss.detach().float()}")
         # Reducing total_loss across all devices if there's more than one CUDA device
         if torch.cuda.device_count() > 1 and train_config.enable_fsdp:
             dist.all_reduce(total_loss, op=dist.ReduceOp.SUM)
@@ -117,6 +116,7 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
         
         print(f"Max CUDA memory allocated was {memtrace.peak} GB")
         print(f"Max CUDA memory reserved was {memtrace.max_reserved} GB")
+        print(f"Peak active CUDA memory was {memtrace.peak_active_gb} GB")
         print(f"Cuda Malloc retires : {memtrace.cuda_malloc_retires}")
         print(f"CPU Total Peak Memory consumed during the train (max): {memtrace.cpu_peaked + memtrace.cpu_begin} GB")
         
@@ -202,7 +202,7 @@ def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer):
                 if train_config.enable_fsdp:
                     batch[key] = batch[key].to(local_rank)
                 else:
-                    batch[key] = batch[key].to('cuda')
+                    batch[key] = batch[key].to('cuda:0')
             # Ensure no gradients are computed for this scope to save memory
             with torch.no_grad():
                 # Forward pass and compute loss