# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

import pytest
from unittest.mock import patch
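

# test_packing mocks out the model, tokenizer, optimizer, scheduler, and the train loop,
# so main() never actually trains. The assertions inspect the dataloaders built for
# samsum_dataset with the "packing" batching strategy.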
@patch('llama_recipes.finetuning.train')
@patch('llama_recipes.finetuning.LlamaTokenizer')
@patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
@patch('llama_recipes.finetuning.optim.AdamW')
@patch('llama_recipes.finetuning.StepLR')
def test_packing(step_lr, optimizer, get_model, tokenizer, train, mocker, setup_tokenizer):
    from llama_recipes.finetuning import main

    setup_tokenizer(tokenizer)

    kwargs = {
        "model_name": "decapoda-research/llama-7b-hf",
        "batch_size_training": 8,
        "val_batch_size": 1,
        "use_peft": False,
        "dataset": "samsum_dataset",
        "batching_strategy": "packing",
    }

    main(**kwargs)
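
    # train() is mocked, so check it was called exactly once and pull the
    # train/eval dataloaders out of its positional call arguments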
    assert train.call_count == 1

    args, kwargs = train.call_args
    train_dataloader = args[1]
    eval_dataloader = args[2]

    assert len(train_dataloader) == 96
    assert len(eval_dataloader) == 42
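
    # with packing, every example in a batch should span the full 4096-token context length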
    batch = next(iter(train_dataloader))

    assert "labels" in batch.keys()
    assert "input_ids" in batch.keys()
    assert "attention_mask" in batch.keys()

    assert batch["labels"][0].size(0) == 4096
    assert batch["input_ids"][0].size(0) == 4096
    assert batch["attention_mask"][0].size(0) == 4096
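

# test_distributed_packing additionally mocks FSDP and the process-group setup and
# fakes a 2-rank run, so each rank should receive half of the packed batches.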
@patch('llama_recipes.finetuning.train')
@patch('llama_recipes.finetuning.LlamaTokenizer')
@patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
@patch('llama_recipes.finetuning.optim.AdamW')
@patch('llama_recipes.finetuning.StepLR')
@patch('llama_recipes.finetuning.setup')
@patch('llama_recipes.finetuning.FSDP')
@patch('llama_recipes.finetuning.torch.distributed.is_initialized')
@patch('llama_recipes.utils.config_utils.dist')
def test_distributed_packing(dist, is_initialized, fsdp, setup, step_lr, optimizer, get_model, tokenizer, train, setup_tokenizer):
    import os

    from llama_recipes.finetuning import main

    setup_tokenizer(tokenizer)
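
    # emulate the environment variables a 2-process distributed launch would set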
    rank = 0
    os.environ['LOCAL_RANK'] = f'{rank}'
    os.environ['RANK'] = f'{rank}'
    os.environ['WORLD_SIZE'] = '2'
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12345'

    kwargs = {
        "model_name": "decapoda-research/llama-7b-hf",
        "batch_size_training": 8,
        "val_batch_size": 1,
        "use_peft": False,
        "dataset": "samsum_dataset",
        "batching_strategy": "packing",
        "enable_fsdp": True,
    }
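
    # pretend torch.distributed is already initialized as rank 0 in a world of size 2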
    is_initialized.return_value = True
    dist.get_rank.return_value = rank
    dist.get_world_size.return_value = 2

    main(**kwargs)

    assert train.call_count == 1

    args, kwargs = train.call_args
    train_dataloader = args[1]
    eval_dataloader = args[2]
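
    # with 2 ranks, each rank's sampler should see half of the packed batches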
    assert len(train_dataloader) == 96 // 2
    assert len(eval_dataloader) == 42 // 2