finetuning.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289
  1. # Copyright (c) Meta Platforms, Inc. and affiliates.
  2. # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
  3. import os
  4. from pkg_resources import packaging
  5. import dataclasses
  6. import fire
  7. import random
  8. import torch
  9. import torch.optim as optim
  10. from peft import get_peft_model, prepare_model_for_kbit_training
  11. from torch.distributed.fsdp import (
  12. FullyShardedDataParallel as FSDP,
  13. ShardingStrategy
  14. )
  15. from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload
  16. from torch.optim.lr_scheduler import StepLR
  17. from transformers import (
  18. AutoTokenizer,
  19. LlamaForCausalLM,
  20. LlamaConfig,
  21. )
  22. from transformers.models.llama.modeling_llama import LlamaDecoderLayer
  23. from llama_recipes.configs import fsdp_config as FSDP_CONFIG
  24. from llama_recipes.configs import train_config as TRAIN_CONFIG
  25. from llama_recipes.data.concatenator import ConcatDataset
  26. from llama_recipes.policies import AnyPrecisionAdamW, apply_fsdp_checkpointing
  27. from llama_recipes.utils import fsdp_auto_wrap_policy
  28. from llama_recipes.utils.config_utils import (
  29. update_config,
  30. generate_peft_config,
  31. generate_dataset_config,
  32. get_dataloader_kwargs,
  33. )
  34. from llama_recipes.utils.dataset_utils import get_preprocessed_dataset
  35. from llama_recipes.utils.fsdp_utils import hsdp_device_mesh
  36. from llama_recipes.utils.train_utils import (
  37. train,
  38. freeze_transformer_layers,
  39. setup,
  40. setup_environ_flags,
  41. clear_gpu_cache,
  42. print_model_size,
  43. get_policies,
  44. )
  45. from accelerate.utils import is_xpu_available
  46. def setup_wandb(train_config, fsdp_config, **kwargs):
  47. try:
  48. import wandb
  49. except ImportError:
  50. raise ImportError(
  51. "You are trying to use wandb which is not currently installed. "
  52. "Please install it using pip install wandb"
  53. )
  54. from llama_recipes.configs import wandb_config as WANDB_CONFIG
  55. wandb_config = WANDB_CONFIG()
  56. update_config(wandb_config, **kwargs)
  57. init_dict = dataclasses.asdict(wandb_config)
  58. run = wandb.init(**init_dict)
  59. run.config.update(train_config)
  60. run.config.update(fsdp_config, allow_val_change=True)
  61. return run
  62. def main(**kwargs):
  63. # Update the configuration for the training and sharding process
  64. train_config, fsdp_config = TRAIN_CONFIG(), FSDP_CONFIG()
  65. update_config((train_config, fsdp_config), **kwargs)
  66. # Set the seeds for reproducibility
  67. if is_xpu_available():
  68. torch.xpu.manual_seed(train_config.seed)
  69. torch.manual_seed(train_config.seed)
  70. random.seed(train_config.seed)
  71. if train_config.enable_fsdp:
  72. setup()
  73. # torchrun specific
  74. local_rank = int(os.environ["LOCAL_RANK"])
  75. rank = int(os.environ["RANK"])
  76. world_size = int(os.environ["WORLD_SIZE"])
  77. if torch.distributed.is_initialized():
  78. if is_xpu_available():
  79. torch.xpu.set_device(local_rank)
  80. elif torch.cuda.is_available():
  81. torch.cuda.set_device(local_rank)
  82. clear_gpu_cache(local_rank)
  83. setup_environ_flags(rank)
  84. wandb_run = None
  85. if train_config.use_wandb:
  86. if not train_config.enable_fsdp or rank==0:
  87. wandb_run = setup_wandb(train_config, fsdp_config, **kwargs)
  88. # Load the pre-trained model and setup its configuration
  89. use_cache = False if train_config.enable_fsdp else None
  90. if train_config.enable_fsdp and train_config.low_cpu_fsdp:
  91. """
  92. for FSDP, we can save cpu memory by loading pretrained model on rank0 only.
  93. this avoids cpu oom when loading large models like llama 70B, in which case
  94. model alone would consume 2+TB cpu mem (70 * 4 * 8). This will add some comms
  95. overhead and currently requires latest nightly.
  96. """
  97. v = packaging.version.parse(torch.__version__)
  98. verify_latest_nightly = v.is_devrelease and v.dev >= 20230701
  99. if not verify_latest_nightly:
  100. raise Exception("latest pytorch nightly build is required to run with low_cpu_fsdp config, "
  101. "please install latest nightly.")
  102. if rank == 0:
  103. model = LlamaForCausalLM.from_pretrained(
  104. train_config.model_name,
  105. load_in_8bit=True if train_config.quantization else None,
  106. device_map="auto" if train_config.quantization else None,
  107. use_cache=use_cache,
  108. attn_implementation="sdpa" if train_config.use_fast_kernels else None,
  109. )
  110. else:
  111. llama_config = LlamaConfig.from_pretrained(train_config.model_name)
  112. llama_config.use_cache = use_cache
  113. with torch.device("meta"):
  114. model = LlamaForCausalLM(llama_config)
  115. else:
  116. model = LlamaForCausalLM.from_pretrained(
  117. train_config.model_name,
  118. load_in_8bit=True if train_config.quantization else None,
  119. device_map="auto" if train_config.quantization else None,
  120. use_cache=use_cache,
  121. attn_implementation="sdpa" if train_config.use_fast_kernels else None,
  122. )
  123. # Load the tokenizer and add special tokens
  124. tokenizer = AutoTokenizer.from_pretrained(train_config.model_name)
  125. tokenizer.pad_token_id = tokenizer.eos_token_id
  126. print_model_size(model, train_config, rank if train_config.enable_fsdp else 0)
  127. # Prepare the model for int8 training if quantization is enabled
  128. if train_config.quantization:
  129. model = prepare_model_for_kbit_training(model)
  130. # Convert the model to bfloat16 if fsdp and pure_bf16 is enabled
  131. if train_config.enable_fsdp and fsdp_config.pure_bf16:
  132. model.to(torch.bfloat16)
  133. if train_config.use_peft:
  134. peft_config = generate_peft_config(train_config, kwargs)
  135. model = get_peft_model(model, peft_config)
  136. model.print_trainable_parameters()
  137. if wandb_run:
  138. wandb_run.config.update(peft_config)
  139. hsdp_device_mesh = None
  140. if fsdp_config.hsdp and fsdp_config.sharding_strategy == ShardingStrategy.HYBRID_SHARD:
  141. hsdp_device_mesh = hsdp_device_mesh(replica_group_size=fsdp_config.replica_group_size, sharding_group_size=fsdp_config.sharding_group_size)
  142. print("HSDP device mesh is ready")
  143. #setting up FSDP if enable_fsdp is enabled
  144. if train_config.enable_fsdp:
  145. if not train_config.use_peft and train_config.freeze_layers:
  146. freeze_transformer_layers(train_config.num_freeze_layers)
  147. mixed_precision_policy, wrapping_policy = get_policies(fsdp_config, rank)
  148. my_auto_wrapping_policy = fsdp_auto_wrap_policy(model, LlamaDecoderLayer)
  149. device_id = 0
  150. if is_xpu_available():
  151. device_id = torch.xpu.current_device()
  152. elif torch.cuda.is_available():
  153. device_id = torch.cuda.current_device()
  154. model = FSDP(
  155. model,
  156. auto_wrap_policy= my_auto_wrapping_policy if train_config.use_peft else wrapping_policy,
  157. cpu_offload=CPUOffload(offload_params=True) if fsdp_config.fsdp_cpu_offload else None,
  158. mixed_precision=mixed_precision_policy if not fsdp_config.pure_bf16 else None,
  159. sharding_strategy=fsdp_config.sharding_strategy,
  160. device_mesh=hsdp_device_mesh,
  161. device_id=device_id,
  162. limit_all_gathers=True,
  163. sync_module_states=train_config.low_cpu_fsdp,
  164. param_init_fn=lambda module: module.to_empty(device=torch.device("cuda"), recurse=False)
  165. if train_config.low_cpu_fsdp and rank != 0 else None,
  166. )
  167. if fsdp_config.fsdp_activation_checkpointing:
  168. apply_fsdp_checkpointing(model)
  169. elif not train_config.quantization and not train_config.enable_fsdp:
  170. if is_xpu_available():
  171. model.to("xpu:0")
  172. elif torch.cuda.is_available():
  173. model.to("cuda")
  174. dataset_config = generate_dataset_config(train_config, kwargs)
  175. # Load and preprocess the dataset for training and validation
  176. dataset_train = get_preprocessed_dataset(
  177. tokenizer,
  178. dataset_config,
  179. split="train",
  180. )
  181. if not train_config.enable_fsdp or rank == 0:
  182. print(f"--> Training Set Length = {len(dataset_train)}")
  183. dataset_val = get_preprocessed_dataset(
  184. tokenizer,
  185. dataset_config,
  186. split="test",
  187. )
  188. if not train_config.enable_fsdp or rank == 0:
  189. print(f"--> Validation Set Length = {len(dataset_val)}")
  190. if train_config.batching_strategy == "packing":
  191. dataset_train = ConcatDataset(dataset_train, chunk_size=train_config.context_length)
  192. train_dl_kwargs = get_dataloader_kwargs(train_config, dataset_train, tokenizer, "train")
  193. # Create DataLoaders for the training and validation dataset
  194. train_dataloader = torch.utils.data.DataLoader(
  195. dataset_train,
  196. num_workers=train_config.num_workers_dataloader,
  197. pin_memory=True,
  198. **train_dl_kwargs,
  199. )
  200. eval_dataloader = None
  201. if train_config.run_validation:
  202. if train_config.batching_strategy == "packing":
  203. dataset_val = ConcatDataset(dataset_val, chunk_size=train_config.context_length)
  204. val_dl_kwargs = get_dataloader_kwargs(train_config, dataset_val, tokenizer, "val")
  205. eval_dataloader = torch.utils.data.DataLoader(
  206. dataset_val,
  207. num_workers=train_config.num_workers_dataloader,
  208. pin_memory=True,
  209. **val_dl_kwargs,
  210. )
  211. # Initialize the optimizer and learning rate scheduler
  212. if fsdp_config.pure_bf16 and fsdp_config.optimizer == "anyprecision":
  213. optimizer = AnyPrecisionAdamW(
  214. model.parameters(),
  215. lr=train_config.lr,
  216. momentum_dtype=torch.bfloat16,
  217. variance_dtype=torch.bfloat16,
  218. use_kahan_summation=False,
  219. weight_decay=train_config.weight_decay,
  220. )
  221. else:
  222. optimizer = optim.AdamW(
  223. model.parameters(),
  224. lr=train_config.lr,
  225. weight_decay=train_config.weight_decay,
  226. )
  227. scheduler = StepLR(optimizer, step_size=1, gamma=train_config.gamma)
  228. # Start the training process
  229. results = train(
  230. model,
  231. train_dataloader,
  232. eval_dataloader,
  233. tokenizer,
  234. optimizer,
  235. scheduler,
  236. train_config.gradient_accumulation_steps,
  237. train_config,
  238. fsdp_config if train_config.enable_fsdp else None,
  239. local_rank if train_config.enable_fsdp else None,
  240. rank if train_config.enable_fsdp else None,
  241. wandb_run,
  242. )
  243. if not train_config.enable_fsdp or rank==0:
  244. [print(f'Key: {k}, Value: {v}') for k, v in results.items()]
  245. if train_config.use_wandb:
  246. for k,v in results.items():
  247. wandb_run.summary[k] = v
  248. if __name__ == "__main__":
  249. fire.Fire(main)