|
@@ -226,12 +226,13 @@ def main(**kwargs):
|
|
|
momentum_dtype=torch.bfloat16,
|
|
|
variance_dtype=torch.bfloat16,
|
|
|
use_kahan_summation=False,
|
|
|
+ weight_decay=train_config.weight_decay,
|
|
|
)
|
|
|
else:
|
|
|
optimizer = optim.AdamW(
|
|
|
model.parameters(),
|
|
|
lr=train_config.lr,
|
|
|
- weight_decay=0.0,
|
|
|
+ weight_decay=train_config.weight_decay,
|
|
|
)
|
|
|
scheduler = StepLR(optimizer, step_size=1, gamma=train_config.gamma)
|
|
|
|