# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

import dataclasses
import os
import random

from pkg_resources import packaging

import fire
import torch
import torch.optim as optim
from peft import get_peft_model, prepare_model_for_kbit_training
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    ShardingStrategy
)
from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload
from torch.optim.lr_scheduler import StepLR
from transformers import (
    AutoTokenizer,
    LlamaForCausalLM,
    LlamaConfig,
)
from transformers.models.llama.modeling_llama import LlamaDecoderLayer

from llama_recipes.configs import fsdp_config as FSDP_CONFIG
from llama_recipes.configs import train_config as TRAIN_CONFIG
from llama_recipes.data.concatenator import ConcatDataset
from llama_recipes.policies import AnyPrecisionAdamW, apply_fsdp_checkpointing

from llama_recipes.utils import fsdp_auto_wrap_policy
from llama_recipes.utils.config_utils import (
    update_config,
    generate_peft_config,
    generate_dataset_config,
    get_dataloader_kwargs,
)
from llama_recipes.utils.dataset_utils import get_preprocessed_dataset

from llama_recipes.utils.fsdp_utils import hsdp_device_mesh
from llama_recipes.utils.train_utils import (
    train,
    freeze_transformer_layers,
    setup,
    setup_environ_flags,
    clear_gpu_cache,
    print_model_size,
    get_policies,
)
from accelerate.utils import is_xpu_available


def setup_wandb(train_config, fsdp_config, **kwargs):
    """Start a Weights & Biases run and record the training/FSDP configs on it.

    Args:
        train_config: training configuration, passed to ``run.config.update``.
        fsdp_config: FSDP configuration, passed to ``run.config.update`` with
            ``allow_val_change=True`` so it may overwrite keys already set.
        **kwargs: overrides applied to a default ``wandb_config`` via
            ``update_config`` before its fields are passed to ``wandb.init``.

    Returns:
        The run object returned by ``wandb.init``.

    Raises:
        ImportError: if the ``wandb`` package is not installed.
    """
    try:
        import wandb
    except ImportError as err:
        raise ImportError(
            "You are trying to use wandb which is not currently installed. "
            "Please install it using pip install wandb"
        ) from err
    from llama_recipes.configs import wandb_config as WANDB_CONFIG
    wandb_config = WANDB_CONFIG()
    update_config(wandb_config, **kwargs)
    init_dict = dataclasses.asdict(wandb_config)
    run = wandb.init(**init_dict)
    run.config.update(train_config)
    run.config.update(fsdp_config, allow_val_change=True)
    return run


def _load_pretrained_model(train_config, use_cache):
    """Load the pretrained ``LlamaForCausalLM`` named by ``train_config.model_name``.

    int8 loading with ``device_map="auto"`` is requested only when
    ``train_config.quantization`` is set; SDPA attention only when
    ``train_config.use_fast_kernels`` is set.
    """
    return LlamaForCausalLM.from_pretrained(
        train_config.model_name,
        load_in_8bit=True if train_config.quantization else None,
        device_map="auto" if train_config.quantization else None,
        use_cache=use_cache,
        attn_implementation="sdpa" if train_config.use_fast_kernels else None,
    )


def main(**kwargs):
    """Fine-tune a Llama causal LM, optionally with FSDP, PEFT and int8 quantization.

    Args:
        **kwargs: overrides applied to the default ``train_config`` and
            ``fsdp_config`` via ``update_config``; also forwarded to
            ``setup_wandb``, ``generate_peft_config`` and
            ``generate_dataset_config``.

    When ``train_config.enable_fsdp`` is set, ``LOCAL_RANK``, ``RANK`` and
    ``WORLD_SIZE`` are read from the environment (as provided by ``torchrun``);
    a missing variable raises ``KeyError``.
    """
    # Update the configuration for the training and sharding process
    train_config, fsdp_config = TRAIN_CONFIG(), FSDP_CONFIG()
    update_config((train_config, fsdp_config), **kwargs)

    # Set the seeds for reproducibility
    if is_xpu_available():
        torch.xpu.manual_seed(train_config.seed)
    torch.manual_seed(train_config.seed)
    random.seed(train_config.seed)

    # NOTE: ``local_rank`` / ``rank`` are only bound in FSDP mode; every later use
    # is guarded by ``enable_fsdp`` (directly or through a short-circuiting ``or``).
    if train_config.enable_fsdp:
        setup()
        # torchrun specific
        local_rank = int(os.environ["LOCAL_RANK"])
        rank = int(os.environ["RANK"])
        world_size = int(os.environ["WORLD_SIZE"])  # not used further in this script

    if torch.distributed.is_initialized():
        if is_xpu_available():
            torch.xpu.set_device(local_rank)
        elif torch.cuda.is_available():
            torch.cuda.set_device(local_rank)
        clear_gpu_cache(local_rank)
        setup_environ_flags(rank)

    # Only one process (rank 0, or the single process without FSDP) logs to wandb.
    wandb_run = None
    if train_config.use_wandb:
        if not train_config.enable_fsdp or rank == 0:
            wandb_run = setup_wandb(train_config, fsdp_config, **kwargs)

    # Load the pre-trained model and setup its configuration
    use_cache = False if train_config.enable_fsdp else None
    if train_config.enable_fsdp and train_config.low_cpu_fsdp:
        # For FSDP, we can save CPU memory by loading the pretrained model on rank 0
        # only. This avoids CPU OOM when loading large models like Llama 70B, in which
        # case the model alone would consume 2+TB of CPU memory (70 * 4 * 8). This
        # adds some communication overhead and currently requires a recent nightly.
        v = packaging.version.parse(torch.__version__)
        verify_latest_nightly = v.is_devrelease and v.dev >= 20230701
        if not verify_latest_nightly:
            raise Exception("latest pytorch nightly build is required to run with low_cpu_fsdp config, "
                            "please install latest nightly.")
        if rank == 0:
            model = _load_pretrained_model(train_config, use_cache)
        else:
            # Non-zero ranks only build the architecture on the meta device; real
            # storage is allocated later by FSDP's ``param_init_fn`` and filled from
            # rank 0 through ``sync_module_states``.
            llama_config = LlamaConfig.from_pretrained(train_config.model_name)
            llama_config.use_cache = use_cache
            with torch.device("meta"):
                model = LlamaForCausalLM(llama_config)
    else:
        model = _load_pretrained_model(train_config, use_cache)

    # Load the tokenizer and add special tokens
    tokenizer = AutoTokenizer.from_pretrained(train_config.model_name)
    tokenizer.pad_token_id = tokenizer.eos_token_id

    print_model_size(model, train_config, rank if train_config.enable_fsdp else 0)

    # Prepare the model for int8 training if quantization is enabled
    if train_config.quantization:
        model = prepare_model_for_kbit_training(model)

    # Convert the model to bfloat16 if fsdp and pure_bf16 is enabled
    if train_config.enable_fsdp and fsdp_config.pure_bf16:
        model.to(torch.bfloat16)

    if train_config.use_peft:
        peft_config = generate_peft_config(train_config, kwargs)
        model = get_peft_model(model, peft_config)
        model.print_trainable_parameters()
        if wandb_run:
            wandb_run.config.update(peft_config)

    # A distinct local name is required here: assigning to ``hsdp_device_mesh``
    # inside this function would make it local everywhere in ``main`` and shadow
    # the imported factory, so the call below would invoke ``None``.
    hsdp_device_mesh_plan = None
    if fsdp_config.hsdp and fsdp_config.sharding_strategy == ShardingStrategy.HYBRID_SHARD:
        hsdp_device_mesh_plan = hsdp_device_mesh(
            replica_group_size=fsdp_config.replica_group_size,
            sharding_group_size=fsdp_config.sharding_group_size,
        )
        print("HSDP device mesh is ready")

    # setting up FSDP if enable_fsdp is enabled
    if train_config.enable_fsdp:
        if not train_config.use_peft and train_config.freeze_layers:
            freeze_transformer_layers(model, train_config.num_freeze_layers)

        mixed_precision_policy, wrapping_policy = get_policies(fsdp_config, rank)
        my_auto_wrapping_policy = fsdp_auto_wrap_policy(model, LlamaDecoderLayer)

        device_id = 0
        if is_xpu_available():
            device_id = torch.xpu.current_device()
        elif torch.cuda.is_available():
            device_id = torch.cuda.current_device()

        # Meta-device ranks (low_cpu_fsdp, rank != 0) need a hook that allocates
        # uninitialized storage on the accelerator. The lambda is parenthesized so
        # the conditional decides whether a hook is passed at all (otherwise the
        # conditional would be swallowed into the lambda body).
        init_device = torch.device("xpu" if is_xpu_available() else "cuda")
        model = FSDP(
            model,
            auto_wrap_policy=my_auto_wrapping_policy if train_config.use_peft else wrapping_policy,
            cpu_offload=CPUOffload(offload_params=True) if fsdp_config.fsdp_cpu_offload else None,
            mixed_precision=mixed_precision_policy if not fsdp_config.pure_bf16 else None,
            sharding_strategy=fsdp_config.sharding_strategy,
            device_mesh=hsdp_device_mesh_plan,
            device_id=device_id,
            limit_all_gathers=True,
            sync_module_states=train_config.low_cpu_fsdp,
            param_init_fn=(
                (lambda module: module.to_empty(device=init_device, recurse=False))
                if train_config.low_cpu_fsdp and rank != 0
                else None
            ),
        )
        if fsdp_config.fsdp_activation_checkpointing:
            apply_fsdp_checkpointing(model)
    elif not train_config.quantization and not train_config.enable_fsdp:
        # Quantized models were already placed by ``device_map="auto"``.
        if is_xpu_available():
            model.to("xpu:0")
        elif torch.cuda.is_available():
            model.to("cuda")

    dataset_config = generate_dataset_config(train_config, kwargs)

    # Load and preprocess the dataset for training and validation
    dataset_train = get_preprocessed_dataset(
        tokenizer,
        dataset_config,
        split="train",
    )
    if not train_config.enable_fsdp or rank == 0:
        print(f"--> Training Set Length = {len(dataset_train)}")

    dataset_val = get_preprocessed_dataset(
        tokenizer,
        dataset_config,
        split="test",
    )
    if not train_config.enable_fsdp or rank == 0:
        print(f"--> Validation Set Length = {len(dataset_val)}")

    if train_config.batching_strategy == "packing":
        dataset_train = ConcatDataset(dataset_train, chunk_size=train_config.context_length)

    train_dl_kwargs = get_dataloader_kwargs(train_config, dataset_train, tokenizer, "train")

    # Create DataLoaders for the training and validation dataset
    train_dataloader = torch.utils.data.DataLoader(
        dataset_train,
        num_workers=train_config.num_workers_dataloader,
        pin_memory=True,
        **train_dl_kwargs,
    )

    eval_dataloader = None
    if train_config.run_validation:
        if train_config.batching_strategy == "packing":
            dataset_val = ConcatDataset(dataset_val, chunk_size=train_config.context_length)

        val_dl_kwargs = get_dataloader_kwargs(train_config, dataset_val, tokenizer, "val")

        eval_dataloader = torch.utils.data.DataLoader(
            dataset_val,
            num_workers=train_config.num_workers_dataloader,
            pin_memory=True,
            **val_dl_kwargs,
        )

    # Initialize the optimizer and learning rate scheduler
    if fsdp_config.pure_bf16 and fsdp_config.optimizer == "anyprecision":
        optimizer = AnyPrecisionAdamW(
            model.parameters(),
            lr=train_config.lr,
            momentum_dtype=torch.bfloat16,
            variance_dtype=torch.bfloat16,
            use_kahan_summation=False,
            weight_decay=train_config.weight_decay,
        )
    else:
        optimizer = optim.AdamW(
            model.parameters(),
            lr=train_config.lr,
            weight_decay=train_config.weight_decay,
        )
    scheduler = StepLR(optimizer, step_size=1, gamma=train_config.gamma)

    # Start the training process
    results = train(
        model,
        train_dataloader,
        eval_dataloader,
        tokenizer,
        optimizer,
        scheduler,
        train_config.gradient_accumulation_steps,
        train_config,
        fsdp_config if train_config.enable_fsdp else None,
        local_rank if train_config.enable_fsdp else None,
        rank if train_config.enable_fsdp else None,
        wandb_run,
    )
    if not train_config.enable_fsdp or rank == 0:
        for k, v in results.items():
            print(f'Key: {k}, Value: {v}')
        if train_config.use_wandb and wandb_run is not None:
            for k, v in results.items():
                wandb_run.summary[k] = v


if __name__ == "__main__":
    fire.Fire(main)