
Fix tqdm bar not changing length after the terminal is resized

hongbo.mo 1 year ago
parent
commit
6217635e87
2 changed files with 4 additions and 4 deletions
  1. src/llama_recipes/datasets/utils.py (+2 −2)
  2. src/llama_recipes/utils/train_utils.py (+2 −2)

+ 2 - 2
src/llama_recipes/datasets/utils.py

@@ -52,7 +52,7 @@ class ConcatDataset(Dataset):
             "labels": [],
             }
         
-        for sample in tqdm(self.dataset, desc="Preprocessing dataset"):
+        for sample in tqdm(self.dataset, desc="Preprocessing dataset", dynamic_ncols=True):
             buffer = {k: v + sample[k] for k,v in buffer.items()}
             
             while len(next(iter(buffer.values()))) > self.chunk_size:
@@ -63,4 +63,4 @@ class ConcatDataset(Dataset):
         return self.samples[idx]
     
     def __len__(self):
-        return len(self.samples)
+        return len(self.samples)

+ 2 - 2
src/llama_recipes/utils/train_utils.py

@@ -69,7 +69,7 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
             model.train()
             total_loss = 0.0
             total_length = len(train_dataloader)//gradient_accumulation_steps
-            pbar = tqdm(colour="blue", desc=f"Training Epoch: {epoch+1}", total=total_length)
+            pbar = tqdm(colour="blue", desc=f"Training Epoch: {epoch+1}", total=total_length, dynamic_ncols=True)
             for step, batch in enumerate(train_dataloader):
                 for key in batch.keys():
                     if train_config.enable_fsdp:
@@ -227,7 +227,7 @@ def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer):
     eval_preds = []
     eval_loss = 0.0  # Initialize evaluation loss
     with MemoryTrace() as memtrace:
-        for step, batch in enumerate(tqdm(eval_dataloader,colour="green", desc="evaluating Epoch")):
+        for step, batch in enumerate(tqdm(eval_dataloader,colour="green", desc="evaluating Epoch", dynamic_ncols=True)):
             for key in batch.keys():
                 if train_config.enable_fsdp:
                     batch[key] = batch[key].to(local_rank)
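
For context: tqdm measures the terminal width once when a bar is created, so a bar built before a window resize keeps its stale length; passing `dynamic_ncols=True` makes tqdm re-query the terminal size on every refresh. A minimal standalone sketch (not part of the commit) illustrating the option:

```python
from time import sleep

from tqdm import tqdm

# With dynamic_ncols=True, tqdm re-checks the terminal size on each
# refresh, so the bar shrinks or grows when the window is resized.
# Without it, the bar keeps the width measured at construction time.
for _ in tqdm(range(50), desc="resize the terminal while this runs",
              dynamic_ncols=True):
    sleep(0.2)
```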