|
@@ -86,7 +86,8 @@ def main(**kwargs):
|
|
|
world_size = int(os.environ["WORLD_SIZE"])
|
|
|
|
|
|
if torch.distributed.is_initialized():
|
|
|
- torch.cuda.set_device(rank)
|
|
|
+ torch.cuda.set_device(local_rank)
|
|
|
+ clear_gpu_cache(local_rank)
|
|
|
setup_environ_flags(rank)
|
|
|
|
|
|
# Calculate gradient accumulation steps
|