@@ -100,8 +100,6 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
             pbar = tqdm(colour="blue", desc=f"Training Epoch: {epoch+1}", total=total_length, dynamic_ncols=True)
             with maybe_run_profiler(train_config) as torch_profiler:
                 for step, batch in enumerate(train_dataloader):
-                    if step > 5:
-                        break
                     gc.collect(1)
                     for key in batch.keys():
                         if train_config.enable_fsdp:
@@ -288,8 +286,6 @@ def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer):
     eval_loss = 0.0 # Initialize evaluation loss
     with MemoryTrace() as memtrace:
         for step, batch in enumerate(tqdm(eval_dataloader,colour="green", desc="evaluating Epoch", dynamic_ncols=True)):
-            if step > 5:
-                break
             gc.collect(1)
             for key in batch.keys():
                 if train_config.enable_fsdp: