test_batching.py

# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

import pytest
from unittest.mock import patch


@pytest.mark.skip_missing_tokenizer
@patch('llama_recipes.finetuning.train')
@patch('llama_recipes.finetuning.LlamaTokenizer')
@patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
@patch('llama_recipes.finetuning.optim.AdamW')
@patch('llama_recipes.finetuning.StepLR')
def test_packing(step_lr, optimizer, get_model, tokenizer, train, mocker, setup_tokenizer):
    from llama_recipes.finetuning import main

    setup_tokenizer(tokenizer)

    kwargs = {
        "model_name": "meta-llama/Llama-2-7b-hf",
        "batch_size_training": 8,
        "val_batch_size": 1,
        "use_peft": False,
        "dataset": "samsum_dataset",
        "batching_strategy": "packing",
    }

    main(**kwargs)

    assert train.call_count == 1

    # Inspect the dataloaders that were handed to the (mocked) train function.
    args, kwargs = train.call_args
    train_dataloader = args[1]
    eval_dataloader = args[2]

    assert len(train_dataloader) == 96
    assert len(eval_dataloader) == 42

    # With the "packing" strategy every sample in a batch is a fixed-length
    # sequence of 4096 tokens.
    batch = next(iter(train_dataloader))

    assert "labels" in batch.keys()
    assert "input_ids" in batch.keys()
    assert "attention_mask" in batch.keys()

    assert batch["labels"][0].size(0) == 4096
    assert batch["input_ids"][0].size(0) == 4096
    assert batch["attention_mask"][0].size(0) == 4096


@pytest.mark.skip_missing_tokenizer
@patch('llama_recipes.finetuning.train')
@patch('llama_recipes.finetuning.LlamaTokenizer')
@patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
@patch('llama_recipes.finetuning.optim.AdamW')
@patch('llama_recipes.finetuning.StepLR')
@patch('llama_recipes.finetuning.setup')
@patch('llama_recipes.finetuning.FSDP')
@patch('llama_recipes.finetuning.torch.distributed.is_initialized')
@patch('llama_recipes.utils.config_utils.dist')
def test_distributed_packing(dist, is_initialized, fsdp, setup, step_lr, optimizer, get_model, tokenizer, train, setup_tokenizer):
    import os

    from llama_recipes.finetuning import main

    setup_tokenizer(tokenizer)

    # Fake a two-process distributed environment, running as rank 0.
    rank = 0
    os.environ['LOCAL_RANK'] = f'{rank}'
    os.environ['RANK'] = f'{rank}'
    os.environ['WORLD_SIZE'] = '2'
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12345'

    kwargs = {
        "model_name": "meta-llama/Llama-2-7b-hf",
        "batch_size_training": 8,
        "val_batch_size": 1,
        "use_peft": False,
        "dataset": "samsum_dataset",
        "batching_strategy": "packing",
        "enable_fsdp": True,
    }

    is_initialized.return_value = True
    dist.get_rank.return_value = rank
    dist.get_world_size.return_value = 2

    main(**kwargs)

    assert train.call_count == 1

    args, kwargs = train.call_args
    train_dataloader = args[1]
    eval_dataloader = args[2]

    # With a mocked world size of 2, each rank should receive half of the
    # packed batches seen in the single-process test above.
    assert len(train_dataloader) == 96 // 2
    assert len(eval_dataloader) == 42 // 2
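
# Usage sketch: assuming pytest and pytest-mock are installed (the `mocker`
# fixture above comes from pytest-mock) and the repository's conftest.py
# provides the `setup_tokenizer` fixture and the `skip_missing_tokenizer`
# marker, these tests can be run with a plain pytest invocation, e.g.:
#
#   python -m pytest test_batching.py -v
#
# The marker presumably skips both tests when the meta-llama/Llama-2-7b-hf
# tokenizer files are not available locally.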