# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

import os
import random

import fire
import torch
import torch.optim as optim
from pkg_resources import packaging
from peft import get_peft_model, prepare_model_for_int8_training
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
)
from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload
from torch.optim.lr_scheduler import StepLR
from transformers import (
    LlamaForCausalLM,
    LlamaTokenizer,
    LlamaConfig,
)
from transformers.models.llama.modeling_llama import LlamaDecoderLayer
from accelerate.utils import is_xpu_available

from llama_recipes.configs import fsdp_config as FSDP_CONFIG
from llama_recipes.configs import train_config as TRAIN_CONFIG
from llama_recipes.data.concatenator import ConcatDataset
from llama_recipes.policies import AnyPrecisionAdamW, apply_fsdp_checkpointing
from llama_recipes.utils import fsdp_auto_wrap_policy
from llama_recipes.utils.config_utils import (
    update_config,
    generate_peft_config,
    generate_dataset_config,
    get_dataloader_kwargs,
)
from llama_recipes.utils.dataset_utils import get_preprocessed_dataset
from llama_recipes.utils.train_utils import (
    train,
    freeze_transformer_layers,
    setup,
    setup_environ_flags,
    clear_gpu_cache,
    print_model_size,
    get_policies,
)
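
# Example invocations (illustrative only; any config field on train_config/fsdp_config
# can be overridden as a --flag, check the llama-recipes docs for the authoritative list):
#   Single GPU, PEFT + int8 quantization:
#       python -m llama_recipes.finetuning --use_peft --peft_method lora --quantization --model_name <path_or_hf_id>
#   Multi-GPU with FSDP (torchrun sets LOCAL_RANK/RANK/WORLD_SIZE used below):
#       torchrun --nnodes 1 --nproc_per_node 4 -m llama_recipes.finetuning --enable_fsdp --use_peft --peft_method lora --model_name <path_or_hf_id>
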
def main(**kwargs):
    # Update the configuration for the training and sharding process
    train_config, fsdp_config = TRAIN_CONFIG(), FSDP_CONFIG()
    update_config((train_config, fsdp_config), **kwargs)

    # Set the seeds for reproducibility
    if is_xpu_available():
        torch.xpu.manual_seed(train_config.seed)
    else:
        torch.cuda.manual_seed(train_config.seed)
    torch.manual_seed(train_config.seed)
    random.seed(train_config.seed)

    if train_config.enable_fsdp:
        setup()
        # torchrun specific
        local_rank = int(os.environ["LOCAL_RANK"])
        rank = int(os.environ["RANK"])
        world_size = int(os.environ["WORLD_SIZE"])

    if torch.distributed.is_initialized():
        if is_xpu_available():
            torch.xpu.set_device(local_rank)
        else:
            torch.cuda.set_device(local_rank)
        clear_gpu_cache(local_rank)
        setup_environ_flags(rank)

    # Load the pre-trained model and setup its configuration
    use_cache = False if train_config.enable_fsdp else None
    if train_config.enable_fsdp and train_config.low_cpu_fsdp:
        """
        For FSDP, we can save CPU memory by loading the pretrained model on rank 0 only.
        This avoids CPU OOM when loading large models like Llama 70B, in which case the
        model alone would consume 2+ TB of CPU memory (70 * 4 * 8). It adds some comms
        overhead and currently requires the latest PyTorch nightly.
        """
        v = packaging.version.parse(torch.__version__)
        verify_latest_nightly = v.is_devrelease and v.dev >= 20230701
        if not verify_latest_nightly:
            raise Exception("latest pytorch nightly build is required to run with low_cpu_fsdp config, "
                            "please install latest nightly.")
        if rank == 0:
            model = LlamaForCausalLM.from_pretrained(
                train_config.model_name,
                load_in_8bit=True if train_config.quantization else None,
                device_map="auto" if train_config.quantization else None,
                use_cache=use_cache,
                attn_implementation="sdpa" if train_config.use_fast_kernels else None,
            )
        else:
            llama_config = LlamaConfig.from_pretrained(train_config.model_name)
            llama_config.use_cache = use_cache
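            # Non-zero ranks build the model on the "meta" device, so no real weights are
            # allocated here; rank 0's weights are broadcast to the other ranks later via
            # FSDP's sync_module_states / param_init_fn (see the FSDP(...) call below).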
            with torch.device("meta"):
                model = LlamaForCausalLM(llama_config)

    else:
        model = LlamaForCausalLM.from_pretrained(
            train_config.model_name,
            load_in_8bit=True if train_config.quantization else None,
            device_map="auto" if train_config.quantization else None,
            use_cache=use_cache,
            attn_implementation="sdpa" if train_config.use_fast_kernels else None,
        )

    # Load the tokenizer and add special tokens
    tokenizer = LlamaTokenizer.from_pretrained(train_config.model_name)
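    # The Llama tokenizer ships without a pad token, so the EOS token id is reused for padding.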
    tokenizer.pad_token_id = tokenizer.eos_token_id

    print_model_size(model, train_config, rank if train_config.enable_fsdp else 0)

    # Prepare the model for int8 training if quantization is enabled
    if train_config.quantization:
        model = prepare_model_for_int8_training(model)

    # Convert the model to bfloat16 if fsdp and pure_bf16 is enabled
    if train_config.enable_fsdp and fsdp_config.pure_bf16:
        model.to(torch.bfloat16)
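        # With pure_bf16 the whole model is cast to bfloat16 up front, so the FSDP
        # wrapper below is created without a separate mixed-precision policy.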

    if train_config.use_peft:
        peft_config = generate_peft_config(train_config, kwargs)
        model = get_peft_model(model, peft_config)
        model.print_trainable_parameters()

    # Setting up FSDP if enable_fsdp is enabled
    if train_config.enable_fsdp:
        if not train_config.use_peft and train_config.freeze_layers:
            freeze_transformer_layers(model, train_config.num_freeze_layers)

        mixed_precision_policy, wrapping_policy = get_policies(fsdp_config, rank)
        my_auto_wrapping_policy = fsdp_auto_wrap_policy(model, LlamaDecoderLayer)
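        # With PEFT, the custom auto-wrap policy keeps the small set of trainable adapter
        # parameters in their own FSDP units, separate from the frozen LlamaDecoderLayer
        # weights; otherwise the standard transformer wrapping policy from get_policies is used.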
        model = FSDP(
            model,
            auto_wrap_policy=my_auto_wrapping_policy if train_config.use_peft else wrapping_policy,
            cpu_offload=CPUOffload(offload_params=True) if fsdp_config.fsdp_cpu_offload else None,
            mixed_precision=mixed_precision_policy if not fsdp_config.pure_bf16 else None,
            sharding_strategy=fsdp_config.sharding_strategy,
            device_id=torch.xpu.current_device() if is_xpu_available() else torch.cuda.current_device(),
            limit_all_gathers=True,
            sync_module_states=train_config.low_cpu_fsdp,
            param_init_fn=lambda module: module.to_empty(device=torch.device("cuda"), recurse=False)
            if train_config.low_cpu_fsdp and rank != 0 else None,
        )
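
        # Activation checkpointing recomputes intermediate activations during the backward
        # pass instead of keeping them in memory, trading extra compute for a large
        # reduction in activation memory.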
        if fsdp_config.fsdp_activation_checkpointing:
            apply_fsdp_checkpointing(model)
    elif not train_config.quantization and not train_config.enable_fsdp:
        if is_xpu_available():
            model.to("xpu:0")
        else:
            model.to("cuda")

    dataset_config = generate_dataset_config(train_config, kwargs)

    # Load and preprocess the dataset for training and validation
    dataset_train = get_preprocessed_dataset(
        tokenizer,
        dataset_config,
        split="train",
    )
    if not train_config.enable_fsdp or rank == 0:
        print(f"--> Training Set Length = {len(dataset_train)}")

    dataset_val = get_preprocessed_dataset(
        tokenizer,
        dataset_config,
        split="test",
    )
    if not train_config.enable_fsdp or rank == 0:
        print(f"--> Validation Set Length = {len(dataset_val)}")
    if train_config.batching_strategy == "packing":
        dataset_train = ConcatDataset(dataset_train, chunk_size=train_config.context_length)

    train_dl_kwargs = get_dataloader_kwargs(train_config, dataset_train, tokenizer, "train")

    # Create DataLoaders for the training and validation dataset
    train_dataloader = torch.utils.data.DataLoader(
        dataset_train,
        num_workers=train_config.num_workers_dataloader,
        pin_memory=True,
        **train_dl_kwargs,
    )

    eval_dataloader = None
    if train_config.run_validation:
        if train_config.batching_strategy == "packing":
            dataset_val = ConcatDataset(dataset_val, chunk_size=train_config.context_length)

        val_dl_kwargs = get_dataloader_kwargs(train_config, dataset_val, tokenizer, "val")

        eval_dataloader = torch.utils.data.DataLoader(
            dataset_val,
            num_workers=train_config.num_workers_dataloader,
            pin_memory=True,
            **val_dl_kwargs,
        )

    # Initialize the optimizer and learning rate scheduler
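    # AnyPrecisionAdamW keeps its momentum and variance state in bfloat16 here, which
    # roughly halves optimizer-state memory compared with fp32 AdamW.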
    if fsdp_config.pure_bf16 and fsdp_config.optimizer == "anyprecision":
        optimizer = AnyPrecisionAdamW(
            model.parameters(),
            lr=train_config.lr,
            momentum_dtype=torch.bfloat16,
            variance_dtype=torch.bfloat16,
            use_kahan_summation=False,
            weight_decay=train_config.weight_decay,
        )
    else:
        optimizer = optim.AdamW(
            model.parameters(),
            lr=train_config.lr,
            weight_decay=train_config.weight_decay,
        )
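    # StepLR multiplies the learning rate by gamma each time it is stepped; with
    # step_size=1 the intent is a per-epoch decay inside the training loop.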
    scheduler = StepLR(optimizer, step_size=1, gamma=train_config.gamma)

    # Start the training process
    results = train(
        model,
        train_dataloader,
        eval_dataloader,
        tokenizer,
        optimizer,
        scheduler,
        train_config.gradient_accumulation_steps,
        train_config,
        fsdp_config if train_config.enable_fsdp else None,
        local_rank if train_config.enable_fsdp else None,
        rank if train_config.enable_fsdp else None,
    )
    if not train_config.enable_fsdp or rank == 0:
        for k, v in results.items():
            print(f"Key: {k}, Value: {v}")


if __name__ == "__main__":
    fire.Fire(main)