
Merge remote-tracking branch 'origin/main' into feature/length_based_batch_sampling

Matthias Reso, 1 year ago
commit e8bb7fbabc

+ 1 - 0
examples/Getting_to_know_Llama.ipynb
File diff suppressed because it is too large

+ 3 - 6
examples/quickstart.ipynb

@@ -32,7 +32,7 @@
    "outputs": [],
    "outputs": [],
    "source": [
    "source": [
     "# %%bash\n",
     "# %%bash\n",
-    "# pip install transformers datasets accelerate sentencepiece protobuf==3.20 py7zr scipy peft bitsandbytes fire torch_tb_profiler ipywidgets\n",
+    "# pip install llama-recipes transformers datasets accelerate sentencepiece protobuf==3.20 py7zr scipy peft bitsandbytes fire torch_tb_profiler ipywidgets\n",
     "# TRANSFORM=`python -c \"import transformers;print('/'.join(transformers.__file__.split('/')[:-1])+'/models/llama/convert_llama_weights_to_hf.py')\"`\n",
     "# TRANSFORM=`python -c \"import transformers;print('/'.join(transformers.__file__.split('/')[:-1])+'/models/llama/convert_llama_weights_to_hf.py')\"`\n",
     "# python ${TRANSFORM} --input_dir models --model_size 7B --output_dir models_hf/7B"
     "# python ${TRANSFORM} --input_dir models --model_size 7B --output_dir models_hf/7B"
    ]
    ]
@@ -130,11 +130,8 @@
     }
    ],
    "source": [
-    "from pathlib import Path\n",
-    "import os\n",
-    "import sys\n",
-    "from utils.dataset_utils import get_preprocessed_dataset\n",
-    "from configs.datasets import samsum_dataset\n",
+    "from llama_recipes.utils.dataset_utils import get_preprocessed_dataset\n",
+    "from llama_recipes.configs.datasets import samsum_dataset\n",
     "\n",
     "train_dataset = get_preprocessed_dataset(tokenizer, samsum_dataset, 'train')"
    ]
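
Note: this hunk switches the quickstart from repo-relative imports to the installed llama_recipes package, which is why the pip line above gains llama-recipes and the pathlib/os/sys bookkeeping disappears. A minimal sketch of the resulting usage, assuming the package is installed and a converted checkpoint exists at models_hf/7B (the tokenizer setup is illustrative, not part of this diff):

    # Hedged sketch: package-level imports replacing the old repo-relative ones.
    from transformers import LlamaTokenizer

    from llama_recipes.configs.datasets import samsum_dataset
    from llama_recipes.utils.dataset_utils import get_preprocessed_dataset

    # Assumption: a HF-format checkpoint converted into models_hf/7B as above.
    tokenizer = LlamaTokenizer.from_pretrained("models_hf/7B")
    train_dataset = get_preprocessed_dataset(tokenizer, samsum_dataset, 'train')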

+ 3 - 3
src/llama_recipes/datasets/grammar_dataset/grammar_dataset_process.ipynb

@@ -35,10 +35,10 @@
     "  (\" '\", \"'\"),\n",
     "  (\" '\", \"'\"),\n",
     "  (\" ?\", \"?\"),\n",
     "  (\" ?\", \"?\"),\n",
     "  (\" !\", \"!\"),\n",
     "  (\" !\", \"!\"),\n",
-    "  (\" :\", \"!\"),\n",
-    "  (\" ;\", \"!\"),\n",
+    "  (\" :\", \":\"),\n",
+    "  (\" ;\", \";\"),\n",
     "  (\" n't\", \"n't\"),\n",
     "  (\" n't\", \"n't\"),\n",
-    "  (\" v\", \"n't\"),\n",
+    "  (\" v\", \"v\"),\n",
     "  (\"2 0 0 6\", \"2006\"),\n",
     "  (\"2 0 0 6\", \"2006\"),\n",
     "  (\"5 5\", \"55\"),\n",
     "  (\"5 5\", \"55\"),\n",
     "  (\"4 0 0\", \"400\"),\n",
     "  (\"4 0 0\", \"400\"),\n",

+ 38 - 34
src/llama_recipes/utils/train_utils.py

@@ -4,6 +4,7 @@
 import os
 import time
 import yaml
+from contextlib import nullcontext
 from pathlib import Path
 from pkg_resources import packaging
 
@@ -25,7 +26,7 @@ from llama_recipes.utils.memory_utils import MemoryTrace
 def set_tokenizer_params(tokenizer: LlamaTokenizer):
     tokenizer.pad_token_id = 0
     tokenizer.padding_side = "left"
-    
+
 # Converting Bytes to Megabytes
 def byte2mb(x):
     return int(x / 2**20)
@@ -33,7 +34,7 @@ def byte2mb(x):
 def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_scheduler, gradient_accumulation_steps, train_config, fsdp_config=None, local_rank=None, rank=None):
     """
     Trains the model on the given dataloader
-    
+
     Args:
         model: The model to be trained
         train_dataloader: The dataloader containing the training data
@@ -45,16 +46,18 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
         train_config: The training configuration
         eval_dataloader: The dataloader containing the eval data
         tokenizer: tokenizer used in the eval for decoding the predicitons
-    
+
     Returns: results dictionary containing average training and validation perplexity and loss
     """
     # Create a gradient scaler for fp16
     if train_config.use_fp16 and train_config.enable_fsdp:
         scaler = ShardedGradScaler()
     elif train_config.use_fp16 and not train_config.enable_fsdp:
-        scaler = torch.cuda.amp.GradScaler() 
+        scaler = torch.cuda.amp.GradScaler()
     if train_config.enable_fsdp:
-        world_size = int(os.environ["WORLD_SIZE"]) 
+        world_size = int(os.environ["WORLD_SIZE"])
+    autocast = torch.cuda.amp.autocast if train_config.use_fp16 else nullcontext
+
     train_prep = []
     train_loss = []
     val_prep = []
@@ -76,7 +79,8 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                         batch[key] = batch[key].to(local_rank)
                     else:
                         batch[key] = batch[key].to('cuda:0')
-                loss = model(**batch).loss
+                with autocast():
+                    loss = model(**batch).loss
                 loss = loss / gradient_accumulation_steps
                 total_loss += loss.detach().float()
                 if train_config.use_fp16:
@@ -97,9 +101,9 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
 
                 pbar.set_description(f"Training Epoch: {epoch+1}/{train_config.num_epochs}, step {step}/{len(train_dataloader)} completed (loss: {loss.detach().float()})")
             pbar.close()
-                
+
         epoch_end_time = time.perf_counter()-epoch_start_time
-        epoch_times.append(epoch_end_time)    
+        epoch_times.append(epoch_end_time)
         # Reducing total_loss across all devices if there's more than one CUDA device
         if torch.cuda.device_count() > 1 and train_config.enable_fsdp:
             dist.all_reduce(total_loss, op=dist.ReduceOp.SUM)
@@ -107,10 +111,10 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
         if train_config.enable_fsdp:
             train_epoch_loss = train_epoch_loss/world_size
         train_perplexity = torch.exp(train_epoch_loss)
-        
+
         train_prep.append(train_perplexity)
         train_loss.append(train_epoch_loss)
-        
+
         if train_config.enable_fsdp:
             if rank==0:
                 print(f"Max CUDA memory allocated was {memtrace.peak} GB")
@@ -124,10 +128,10 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
             print(f"Peak active CUDA memory was {memtrace.peak_active_gb} GB")
             print(f"Cuda Malloc retires : {memtrace.cuda_malloc_retires}")
             print(f"CPU Total Peak Memory consumed during the train (max): {memtrace.cpu_peaked + memtrace.cpu_begin} GB")
-        
+
         # Update the learning rate as needed
         lr_scheduler.step()
-          
+
         if train_config.run_validation:
             eval_ppl, eval_epoch_loss = evaluation(model, train_config, eval_dataloader, local_rank, tokenizer)
             checkpoint_start_time = time.perf_counter()
@@ -140,23 +144,23 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                             print(f"we are about to save the PEFT modules")
                     else:
                         print(f"we are about to save the PEFT modules")
-                    model.save_pretrained(train_config.output_dir)  
+                    model.save_pretrained(train_config.output_dir)
                     if train_config.enable_fsdp:
-                        if rank==0: 
+                        if rank==0:
                             print(f"PEFT modules are saved in {train_config.output_dir} directory")
                     else:
                         print(f"PEFT modules are saved in {train_config.output_dir} directory")
-                        
+
                 else:
                     if not train_config.use_peft and fsdp_config.checkpoint_type == StateDictType.FULL_STATE_DICT:
-                        
+
                         save_model_checkpoint(
                             model, optimizer, rank, train_config, epoch=epoch
                         )
                     elif not train_config.use_peft and fsdp_config.checkpoint_type == StateDictType.SHARDED_STATE_DICT:
                         print(" Saving the FSDP model checkpoints using SHARDED_STATE_DICT")
                         print("=====================================================")
-                        
+
                         save_model_and_optimizer_sharded(model, rank, train_config)
                         if train_config.save_optimizer:
                             save_model_and_optimizer_sharded(model, rank, train_config, optim=optimizer)
@@ -168,7 +172,7 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                             model, optimizer, rank, train_config, epoch=epoch
                         )
                         print(" Saving the FSDP model checkpoints and optimizer using FULL_STATE_DICT")
-                        print("=====================================================")                     
+                        print("=====================================================")
                 if train_config.enable_fsdp:
                     dist.barrier()
             checkpoint_end_time = time.perf_counter() - checkpoint_start_time
@@ -192,8 +196,8 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
     avg_train_prep = sum(train_prep)/len(train_prep)
     avg_train_loss = sum(train_loss)/len(train_loss)
     if train_config.run_validation:
-        avg_eval_prep = sum(val_prep)/len(val_prep) 
-        avg_eval_loss = sum(val_loss)/len(val_loss) 
+        avg_eval_prep = sum(val_prep)/len(val_prep)
+        avg_eval_loss = sum(val_loss)/len(val_loss)
 
     results['avg_train_prep'] = avg_train_prep
     results['avg_train_loss'] = avg_train_loss
@@ -202,27 +206,27 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
         results['avg_eval_loss'] = avg_eval_loss
     results["avg_epoch_time"] = avg_epoch_time
     results["avg_checkpoint_time"] = avg_checkpoint_time
-    
+
     #saving the training params including fsdp setting for reference.
     if train_config.enable_fsdp and not train_config.use_peft:
         save_train_params(train_config, fsdp_config, rank)
-        
+
     return results
 
 def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer):
     """
     Evaluates the model on the given dataloader
-    
+
     Args:
         model: The model to evaluate
         eval_dataloader: The dataloader containing the evaluation data
         local_rank: The rank of the current node in a distributed setting
         tokenizer: The tokenizer used to decode predictions
-    
+
     Returns: eval_ppl, eval_epoch_loss
     """
     if train_config.enable_fsdp:
-        world_size = int(os.environ["WORLD_SIZE"]) 
+        world_size = int(os.environ["WORLD_SIZE"])
     model.eval()
     eval_preds = []
     eval_loss = 0.0  # Initialize evaluation loss
@@ -244,24 +248,24 @@ def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer):
             eval_preds.extend(
                 tokenizer.batch_decode(preds.detach().cpu().numpy(), skip_special_tokens=True)
             )
-    
+
     # If there's more than one CUDA device, reduce evaluation loss across all devices
     if torch.cuda.device_count() > 1 and train_config.enable_fsdp:
         dist.all_reduce(eval_loss, op=dist.ReduceOp.SUM)
-    
+
     # Compute average loss and perplexity
     eval_epoch_loss = eval_loss / len(eval_dataloader)
     if train_config.enable_fsdp:
         eval_epoch_loss = eval_epoch_loss/world_size
     eval_ppl = torch.exp(eval_epoch_loss)
-    
+
     # Print evaluation metrics
     if train_config.enable_fsdp:
         if local_rank==0:
             print(f" {eval_ppl=} {eval_epoch_loss=}")
     else:
         print(f" {eval_ppl=} {eval_epoch_loss=}")
-        
+
     return eval_ppl, eval_epoch_loss
 
 def freeze_transformer_layers(model, num_layer):
@@ -275,8 +279,8 @@ def check_frozen_layers_peft_model(model):
      for i, layer in enumerate(model.base_model.model.model.layers):
             for name, param in layer.named_parameters():
                 print(f"Layer {i}, parameter {name}: requires_grad = {param.requires_grad}")
-                
-                
+
+
 def setup():
     """Initialize the process group for distributed training"""
     dist.init_process_group("nccl")
@@ -289,7 +293,7 @@ def setup_environ_flags(rank):
     # os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL"
     # This flag will help with CUDA memory fragmentations that can lead into OOM in some cases.
     # Note this is only availble in PyTorch Nighlies (as of July 30 2023)
-    # os.environ['PYTORCH_CUDA_ALLOC_CONF']='expandable_segments:True' 
+    # os.environ['PYTORCH_CUDA_ALLOC_CONF']='expandable_segments:True'
     if rank == 0:
         print(f"--> Running with torch dist debug set to detail")
 
@@ -334,7 +338,7 @@ def print_model_size(model, config, rank: int = 0) -> None:
 
 def get_policies(cfg, rank):
     """Get the policies for mixed precision and fsdp wrapping"""
-    
+
     verify_bfloat_support = (
     torch.version.cuda
     and torch.cuda.is_bf16_supported()
@@ -370,7 +374,7 @@ def save_train_params(train_config, fsdp_config, rank):
     This will be used by converter script in the inference folder to fetch the HF model name or path.
     It also would be hepful as a log for future references.
     """
-    # Convert the train_config and fsdp_config objects to dictionaries, 
+    # Convert the train_config and fsdp_config objects to dictionaries,
     # converting all values to strings to ensure they can be serialized into a YAML file
     train_config_dict = {k: str(v) for k, v in vars(train_config).items() if not k.startswith('__')}
     fsdp_config_dict = {k: str(v) for k, v in vars(fsdp_config).items() if not k.startswith('__')}
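
Note: beyond trailing-whitespace cleanup, the substantive change in this file is the context selection: torch.cuda.amp.autocast when use_fp16 is set, contextlib.nullcontext otherwise, so the forward pass can be wrapped unconditionally. A self-contained sketch of the pattern, with a toy model standing in for the real one:

    from contextlib import nullcontext

    import torch
    import torch.nn as nn

    # Pick the context exactly as the commit does; use_fp16 stands in
    # for train_config.use_fp16 here.
    use_fp16 = torch.cuda.is_available()
    autocast = torch.cuda.amp.autocast if use_fp16 else nullcontext

    # Toy stand-ins; the real loop computes model(**batch).loss.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = nn.Linear(8, 1).to(device)
    batch = torch.randn(4, 8, device=device)

    # One line serves both paths because both objects are context managers.
    with autocast():
        loss = model(batch).sum()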

+ 1 - 1
tests/test_finetuning.py

@@ -96,7 +96,7 @@ def test_finetuning_peft(step_lr, optimizer, get_peft_model, gen_peft_config, ge
 @patch('llama_recipes.finetuning.get_preprocessed_dataset')
 @patch('llama_recipes.finetuning.get_peft_model')
 @patch('llama_recipes.finetuning.StepLR')
-def test_finetuning_weight_decay(step_lr, get_peft_model, get_dataset, tokenizer, get_model, train):
+def test_finetuning_weight_decay(step_lr, get_peft_model, get_dataset, tokenizer, get_model, train, mocker):
     kwargs = {"weight_decay": 0.01}
 
     get_dataset.return_value = get_fake_dataset()

+ 18 - 5
tests/test_train_utils.py

@@ -1,17 +1,22 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
 
+from unittest.mock import patch
+
 import torch
 
 from llama_recipes.utils.train_utils import train
 
-def test_gradient_accumulation(mocker):
-    # import sys
-    # sys.path.append('/home/ubuntu/llama-recipes/')
+@patch("llama_recipes.utils.train_utils.MemoryTrace")
+@patch("llama_recipes.utils.train_utils.nullcontext")
+@patch("llama_recipes.utils.train_utils.torch.cuda.amp.GradScaler")
+@patch("llama_recipes.utils.train_utils.torch.cuda.amp.autocast")
+def test_gradient_accumulation(autocast, scaler, nullcontext, mem_trace, mocker):
 
     model = mocker.MagicMock(name="model")
     model().loss.__truediv__().detach.return_value = torch.tensor(1)
-    batch = {"input": torch.zeros(1)}
+    mock_tensor = mocker.MagicMock(name="tensor")
+    batch = {"input": mock_tensor}
     train_dataloader = [batch, batch, batch, batch, batch]
     eval_dataloader = None
     tokenizer = mocker.MagicMock()
@@ -37,7 +42,13 @@ def test_gradient_accumulation(mocker):
     assert optimizer.zero_grad.call_count == 5
     optimizer.zero_grad.reset_mock()
 
+    assert nullcontext.call_count == 5
+    nullcontext.reset_mock()
+
+    assert autocast.call_count == 0
+
     gradient_accumulation_steps = 2
+    train_config.use_fp16 = True
     train(
         model,
         train_dataloader,
@@ -48,4 +59,6 @@ def test_gradient_accumulation(mocker):
         gradient_accumulation_steps,
         train_config,
     )
-    assert optimizer.zero_grad.call_count == 3
+    assert optimizer.zero_grad.call_count == 3
+    assert nullcontext.call_count == 0
+    assert autocast.call_count == 5
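
Note: with five batches in the dataloader, these assertions pin down the context-manager selection: on the default fp32 path nullcontext is entered once per batch (5 calls, autocast 0), and after setting train_config.use_fp16 = True the counts invert (autocast 5, nullcontext 0). The mock arguments arrive bottom-up because unittest.mock applies the innermost @patch decorator first, hence the (autocast, scaler, nullcontext, mem_trace, mocker) parameter order.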