@@ -66,7 +66,6 @@ import torch
import torch.nn as nn
import torch.cuda.nccl as nccl
import torch.distributed as dist
-from torch.distributed.fsdp._common_utils import _is_fsdp_flattened
from transformers.models.llama.modeling_llama import LlamaDecoderLayer