@@ -66,7 +66,6 @@ import torch
 import torch.nn as nn
 import torch.cuda.nccl as nccl
 import torch.distributed as dist
-from torch.distributed.fsdp._common_utils import _is_fsdp_flattened
 
 from transformers.models.llama.modeling_llama import LlamaDecoderLayer
 