Browse Source

fixing test_gradient_accumulation and test_save_to_json

Kai Wu 7 months ago
parent
commit
362cda0fa6
1 changed files with 4 additions and 2 deletions
  1. 4 2
      tests/test_train_utils.py

+ 4 - 2
tests/test_train_utils.py

@@ -44,6 +44,8 @@ def test_gradient_accumulation(autocast, scaler, nullcontext, mem_trace, mocker)
     train_config.use_fp16 = False
     train_config.use_fp16 = False
     train_config.run_validation = False
     train_config.run_validation = False
     train_config.gradient_clipping = False
     train_config.gradient_clipping = False
+    train_config.max_train_step = 0
+    train_config.max_eval_step = 0
     train_config.save_metrics = False
     train_config.save_metrics = False
 
 
     train(
     train(
@@ -98,6 +100,8 @@ def test_save_to_json(temp_output_dir, mocker):
     train_config.run_validation = False
     train_config.run_validation = False
     train_config.gradient_clipping = False
     train_config.gradient_clipping = False
     train_config.save_metrics = True
     train_config.save_metrics = True
+    train_config.max_train_step = 0
+    train_config.max_eval_step = 0
     train_config.output_dir = temp_output_dir
     train_config.output_dir = temp_output_dir
 
 
     results = train(
     results = train(
@@ -114,5 +118,3 @@ def test_save_to_json(temp_output_dir, mocker):
 
 
     assert results["metrics_filename"] not in ["", None]
     assert results["metrics_filename"] not in ["", None]
     assert os.path.isfile(results["metrics_filename"])
     assert os.path.isfile(results["metrics_filename"])
-
-