
Adjust imports to the package structure and clean up unused imports

Matthias Reso · 1 year ago · commit cf678b9bf0

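The change is mechanical but has one practical consequence worth spelling out: once the sources live under src/llama_recipes/ as an installable package, sibling modules are reached with relative imports (from .configs import ...) instead of flat top-level names (from configs import ...), and such modules can only be imported as part of the package rather than run as loose scripts. A minimal, hypothetical sketch of that behaviour (illustrative only, not part of the commit):

# demo_relative_import.py -- hypothetical file; the imported names mirror this
# commit, but the script itself is only an illustration.
try:
    # Resolves when this file is imported as llama_recipes.demo_relative_import,
    # e.g. after an editable install or via `python -m ...`.
    from .configs import train_config
except ImportError:
    # Running the file directly fails: a top-level script has no parent
    # package against which "." can be resolved.
    print(f"no parent package here (__package__ = {__package__!r})")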
+ 2 - 2
src/llama_recipes/configs/fsdp.py

@@ -1,8 +1,8 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
 
-from dataclasses import dataclass, field
-from typing import ClassVar
+from dataclasses import dataclass
+
 from torch.distributed.fsdp import ShardingStrategy
 from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType
 

+ 1 - 1
src/llama_recipes/configs/peft.py

@@ -1,7 +1,7 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
 
-from dataclasses import dataclass, field
+from dataclasses import dataclass
 from typing import ClassVar, List
 
 @dataclass

+ 1 - 1
src/llama_recipes/configs/training.py

@@ -1,7 +1,7 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
+
 from dataclasses import dataclass
-from typing import ClassVar
 
 
 @dataclass

+ 2 - 4
src/llama_recipes/datasets/alpaca_dataset.py

@@ -5,12 +5,10 @@
 
 import copy
 import json
-import os
-import torch
 
-from sentencepiece import SentencePieceProcessor
+import torch
 from torch.utils.data import Dataset
-from typing import List
+
 
 PROMPT_DICT = {
     "prompt_input": (

+ 2 - 18
src/llama_recipes/datasets/grammar_dataset/grammar_dataset.py

@@ -4,29 +4,13 @@
 # For dataset details visit: https://huggingface.co/datasets/jfleg
 # For download and preparation see: recipes/ft_datasets/grammar_dataset/grammar_dataset_process.ipynb
 
-import argparse
-import csv
-import glob
-import os
-import json
-import time
-import logging
-import random
-import re
-from itertools import chain
-from string import punctuation
-
-
-import pandas as pd
-import numpy as np
-import torch
-from torch.utils.data import Dataset
 
 from datasets import load_dataset
 from pathlib import Path
 
-from ft_datasets.utils import ConcatDataset
+from torch.utils.data import Dataset
 
+from ..utils import ConcatDataset
 
 
 class grammar(Dataset):

+ 1 - 0
src/llama_recipes/datasets/samsum_dataset.py

@@ -4,6 +4,7 @@
 # For dataset details visit: https://huggingface.co/datasets/samsum
 
 import datasets
+
 from .utils import Concatenator
 
 def get_preprocessed_samsum(dataset_config, tokenizer, split):

+ 1 - 0
src/llama_recipes/datasets/utils.py

@@ -3,6 +3,7 @@
 
 from tqdm import tqdm
 from itertools import chain
+
 from torch.utils.data import Dataset
 
 class Concatenator(object):

+ 8 - 9
src/llama_recipes/finetuning.py

@@ -2,13 +2,13 @@
 # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
 
 import os
+from pkg_resources import packaging
 
 import fire
 import torch
 import torch.distributed as dist
 import torch.optim as optim
 from peft import get_peft_model, prepare_model_for_int8_training
-from pkg_resources import packaging
 from torch.distributed.fsdp import (
     FullyShardedDataParallel as FSDP,
 )
@@ -22,19 +22,18 @@ from transformers import (
 )
 from transformers.models.llama.modeling_llama import LlamaDecoderLayer
 
-import policies
-from configs import fsdp_config, train_config
-from policies import AnyPrecisionAdamW
+from .configs import fsdp_config, train_config
+from .policies import AnyPrecisionAdamW, apply_fsdp_checkpointing
 
-from utils import fsdp_auto_wrap_policy
-from utils.config_utils import (
+from .utils import fsdp_auto_wrap_policy
+from .utils.config_utils import (
     update_config,
     generate_peft_config,
     generate_dataset_config,
 )
-from utils.dataset_utils import get_preprocessed_dataset
+from .utils.dataset_utils import get_preprocessed_dataset
 
-from utils.train_utils import (
+from .utils.train_utils import (
     train,
     freeze_transformer_layers,
     setup,
@@ -153,7 +152,7 @@ def main(**kwargs):
             if train_config.low_cpu_fsdp and rank != 0 else None,
         )
         if fsdp_config.fsdp_activation_checkpointing:
-            policies.apply_fsdp_checkpointing(model)
+            apply_fsdp_checkpointing(model)
     elif not train_config.quantization and not train_config.enable_fsdp:
         model.to("cuda")
 

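With finetuning.py now importing its siblings relatively, it is meant to be used through the llama_recipes package rather than launched from inside the source tree. A hedged usage sketch from outside the package follows; the keyword arguments are assumptions about train_config fields made for illustration, not values taken from this diff.

# run_finetune_sketch.py -- assumes the package is importable (e.g. after an
# editable install); roughly what `python -m llama_recipes.finetuning ...`
# would do through its fire-based entry point.
from llama_recipes.finetuning import main

# main(**kwargs) forwards its kwargs to update_config() for train_config and
# fsdp_config; the field names below are illustrative assumptions.
main(model_name="path/to/llama-2-7b-hf", use_peft=True, peft_method="lora")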
+ 6 - 6
src/llama_recipes/inference/chat_completion.py

@@ -2,18 +2,18 @@
 # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
 
 # from accelerate import init_empty_weights, load_checkpoint_and_dispatch
+
 import fire
-import torch
 import os
 import sys
-import warnings
 from typing import List
 
-from peft import PeftModel, PeftConfig
-from transformers import LlamaConfig, LlamaTokenizer, LlamaForCausalLM
-from safety_utils import get_safety_checker
+import torch
 from model_utils import load_model, load_peft_model
-from chat_utils import read_dialogs_from_file, format_tokens
+from transformers import LlamaTokenizer
+from safety_utils import get_safety_checker
+
+from .chat_utils import read_dialogs_from_file, format_tokens
 
 def main(
     model_name,

+ 2 - 1
src/llama_recipes/inference/chat_utils.py

@@ -1,8 +1,9 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
 
-from typing import List, Literal, Optional, Tuple, TypedDict, Union
 import json
+from typing import List, Literal, TypedDict
+
 
 Role = Literal["user", "assistant"]
 

+ 4 - 2
src/llama_recipes/inference/checkpoint_converter_fsdp_hf.py

@@ -4,12 +4,14 @@
 # from accelerate import init_empty_weights, load_checkpoint_and_dispatch
 
 import fire
-import torch
 import os
 import sys
 import yaml
+
 from transformers import LlamaTokenizer
-from model_utils import  load_llama_from_config
+
+from .model_utils import load_llama_from_config
+
 # Get the current file's directory
 current_directory = os.path.dirname(os.path.abspath(__file__))
 

+ 5 - 4
src/llama_recipes/inference/inference.py

@@ -4,15 +4,16 @@
 # from accelerate import init_empty_weights, load_checkpoint_and_dispatch
 
 import fire
-import torch
 import os
 import sys
 import time
-from typing import List
 
+import torch
 from transformers import LlamaTokenizer
-from safety_utils import get_safety_checker
-from model_utils import load_model, load_peft_model, load_llama_from_config
+
+from .safety_utils import get_safety_checker
+from .model_utils import load_model, load_peft_model
+
 
 def main(
     model_name,

+ 0 - 2
src/llama_recipes/inference/safety_utils.py

@@ -5,8 +5,6 @@ import os
 import torch
 import warnings
 
-from peft import PeftConfig
-from transformers import LlamaConfig, LlamaTokenizer, LlamaForCausalLM
 
 # Class for performing safety checks using AuditNLG library
 class AuditNLGSensitiveTopics(object):

+ 2 - 9
src/llama_recipes/inference/vLLM_inference.py

@@ -1,20 +1,13 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
 
-from accelerate import init_empty_weights, load_checkpoint_and_dispatch
 import fire
+
 import torch
-import os
-import sys
-from peft import PeftModel, PeftConfig
-from transformers import (
-    LlamaConfig,
-    LlamaTokenizer,
-    LlamaForCausalLM
-)
 from vllm import LLM
 from vllm import LLM, SamplingParams
 
+
 torch.cuda.manual_seed(42)
 torch.manual_seed(42)
 

+ 2 - 6
src/llama_recipes/policies/activation_checkpointing_functions.py

@@ -1,18 +1,14 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
 
-import torch
-import os
-import torch.distributed as dist
+from functools import partial
+
 from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
     checkpoint_wrapper,
     CheckpointImpl,
     apply_activation_checkpointing,
 )
-
-from transformers.models.t5.modeling_t5 import T5Block
 from transformers.models.llama.modeling_llama import LlamaDecoderLayer
-from functools import partial
 
 non_reentrant_wrapper = partial(
     checkpoint_wrapper,

+ 0 - 4
src/llama_recipes/policies/mixed_precision.py

@@ -4,11 +4,7 @@
 import torch
 
 from torch.distributed.fsdp import (
-    # FullyShardedDataParallel as FSDP,
-    # CPUOffload,
     MixedPrecision,
-    # BackwardPrefetch,
-    # ShardingStrategy,
 )
 
 # requires grad scaler in main loop

+ 1 - 15
src/llama_recipes/policies/wrapping.py

@@ -1,28 +1,14 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
 
-import torch.distributed as dist
-import torch.nn as nn
-import torch
+import functools
 
 from transformers.models.llama.modeling_llama import LlamaDecoderLayer
-
-from torch.distributed.fsdp.fully_sharded_data_parallel import (
-    FullyShardedDataParallel as FSDP,
-    CPUOffload,
-    BackwardPrefetch,
-    MixedPrecision,
-)
 from torch.distributed.fsdp.wrap import (
     transformer_auto_wrap_policy,
     size_based_auto_wrap_policy,
-    enable_wrap,
-    wrap,
 )
 
-import functools
-from typing import Type
-
 
 def get_size_policy(min_params=1e8):
     num_wrap_policy = functools.partial(

+ 2 - 2
src/llama_recipes/utils/config_utils.py

@@ -3,14 +3,14 @@
 
 import inspect
 from dataclasses import fields
+
 from peft import (
     LoraConfig,
     AdaptionPromptConfig,
     PrefixTuningConfig,
 )
 
-import configs.datasets as datasets
-from configs import lora_config, llama_adapter_config, prefix_config, train_config
+from ..configs import datasets, lora_config, llama_adapter_config, prefix_config, train_config
 from .dataset_utils import DATASET_PREPROC
 
 

+ 3 - 4
src/llama_recipes/utils/dataset_utils.py

@@ -1,16 +1,15 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
 
-import torch
-
 from functools import partial
 
-from ft_datasets import (
+import torch
+
+from ..datasets import (
     get_grammar_dataset,
     get_alpaca_dataset,
     get_samsum_dataset,
 )
-from typing import Optional
 
 
 DATASET_PREPROC = {

+ 0 - 3
src/llama_recipes/utils/fsdp_utils.py

@@ -3,10 +3,7 @@
 
 def fsdp_auto_wrap_policy(model, transformer_layer_name):
     import functools
-    import os
 
-    from accelerate import FullyShardedDataParallelPlugin
-    from transformers.models.t5.modeling_t5 import T5Block
     from torch.distributed.fsdp.wrap import _or_policy, lambda_auto_wrap_policy, transformer_auto_wrap_policy
 
     from peft.tuners import PrefixEncoder, PromptEmbedding, PromptEncoder

+ 2 - 4
src/llama_recipes/utils/memory_utils.py

@@ -1,12 +1,10 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
+
 import gc
-import os
-import sys
+import psutil
 import threading
 
-import numpy as np
-import psutil
 import torch
 
 def byte2gb(x):

+ 18 - 33
src/llama_recipes/utils/train_utils.py

@@ -2,40 +2,25 @@
 # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
 
 import os
-import sys
-from typing import List
-import yaml
 import time
+import yaml
+from pathlib import Path
+from pkg_resources import packaging
+
 
-import fire
 import torch
-import transformers
-from datasets import load_dataset
-from tqdm import tqdm
-"""
-Unused imports:
-import torch.nn as nn
-import bitsandbytes as bnb
-"""
-from torch.nn import functional as F
-from peft import (
-    LoraConfig,
-    get_peft_model,
-    get_peft_model_state_dict,
-    prepare_model_for_int8_training,
-    set_peft_model_state_dict,
-)
-from transformers import LlamaForCausalLM, LlamaTokenizer
-from torch.distributed.fsdp import StateDictType
-import torch.distributed as dist
-from pkg_resources import packaging
-from .memory_utils import MemoryTrace
-import model_checkpointing
 import torch.cuda.nccl as nccl
+import torch.distributed as dist
+from torch.distributed.fsdp import StateDictType
 from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
-from pathlib import Path
-sys.path.append(str(Path(__file__).resolve().parent.parent))
-from policies import bfSixteen, fpSixteen,bfSixteen_mixed, get_llama_wrapper
+from tqdm import tqdm
+from transformers import LlamaTokenizer
+
+
+from .memory_utils import MemoryTrace
+from ..model_checkpointing import save_model_checkpoint, save_model_and_optimizer_sharded, save_optimizer_checkpoint
+from ..policies import fpSixteen, bfSixteen_mixed, get_llama_wrapper
+
 
 def set_tokenizer_params(tokenizer: LlamaTokenizer):
     tokenizer.pad_token_id = 0
@@ -162,21 +147,21 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                 else:
                     if not train_config.use_peft and fsdp_config.checkpoint_type == StateDictType.FULL_STATE_DICT:
                         
-                        model_checkpointing.save_model_checkpoint(
+                        save_model_checkpoint(
                             model, optimizer, rank, train_config, epoch=epoch
                         )
                     elif not train_config.use_peft and fsdp_config.checkpoint_type == StateDictType.SHARDED_STATE_DICT:
                         print(" Saving the FSDP model checkpoints using SHARDED_STATE_DICT")
                         print("=====================================================")
                         
-                        model_checkpointing.save_model_and_optimizer_sharded(model, rank, train_config)
+                        save_model_and_optimizer_sharded(model, rank, train_config)
                         if train_config.save_optimizer:
-                            model_checkpointing.save_model_and_optimizer_sharded(model, rank, train_config, optim=optimizer)
+                            save_model_and_optimizer_sharded(model, rank, train_config, optim=optimizer)
                             print(" Saving the FSDP model checkpoints and optimizer using SHARDED_STATE_DICT")
                             print("=====================================================")
 
                     if not train_config.use_peft and  train_config.save_optimizer:
-                        model_checkpointing.save_optimizer_checkpoint(
+                        save_optimizer_checkpoint(
                             model, optimizer, rank, train_config, epoch=epoch
                         )
                         print(" Saving the FSDP model checkpoints and optimizer using FULL_STATE_DICT")
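One detail in this last hunk: the old `import model_checkpointing` has no direct relative equivalent, since `import ..model_checkpointing` is not valid Python syntax. The commit therefore imports the needed functions by name and drops the module qualifier from the call sites; re-binding the subpackage would also have worked. Both forms, for comparison (sketch only, names taken from the diff above):

# Form used in this commit: bind the checkpointing functions directly.
from ..model_checkpointing import (
    save_model_checkpoint,
    save_model_and_optimizer_sharded,
    save_optimizer_checkpoint,
)

# Equivalent alternative that would keep the old qualified call style:
# from .. import model_checkpointing
# model_checkpointing.save_model_checkpoint(model, optimizer, rank, train_config, epoch=epoch)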