Procházet zdrojové kódy

fixing test_gradient_accumulation and test_save_to_json

Kai Wu před 7 měsíci
rodič
revize
362cda0fa6
1 změnil soubory, kde provedl 4 přidání a 2 odebrání
  1. 4 2
      tests/test_train_utils.py

+ 4 - 2
tests/test_train_utils.py

@@ -44,6 +44,8 @@ def test_gradient_accumulation(autocast, scaler, nullcontext, mem_trace, mocker)
     train_config.use_fp16 = False
     train_config.run_validation = False
     train_config.gradient_clipping = False
+    train_config.max_train_step = 0
+    train_config.max_eval_step = 0
     train_config.save_metrics = False
 
     train(
@@ -98,6 +100,8 @@ def test_save_to_json(temp_output_dir, mocker):
     train_config.run_validation = False
     train_config.gradient_clipping = False
     train_config.save_metrics = True
+    train_config.max_train_step = 0
+    train_config.max_eval_step = 0
     train_config.output_dir = temp_output_dir
 
     results = train(
@@ -114,5 +118,3 @@ def test_save_to_json(temp_output_dir, mocker):
 
     assert results["metrics_filename"] not in ["", None]
     assert os.path.isfile(results["metrics_filename"])
-
-