lchu 1 год назад
Родитель
Сommit
0c51b47262
1 измененных файлов с 2 добавлено и 1 удалено
  1. 2 1
      llama_finetuning.py

+ 2 - 1
llama_finetuning.py

@@ -86,7 +86,8 @@ def main(**kwargs):
         world_size = int(os.environ["WORLD_SIZE"])
         world_size = int(os.environ["WORLD_SIZE"])
 
 
     if torch.distributed.is_initialized():
     if torch.distributed.is_initialized():
-        torch.cuda.set_device(rank)
+        torch.cuda.set_device(local_rank)
+        clear_gpu_cache(local_rank)
         setup_environ_flags(rank)
         setup_environ_flags(rank)
 
 
     # Calculate gradient accumulation steps
     # Calculate gradient accumulation steps