|
@@ -4,6 +4,7 @@
|
|
|
import os
|
|
|
from pkg_resources import packaging
|
|
|
|
|
|
+import dataclasses
|
|
|
import fire
|
|
|
import random
|
|
|
import torch
|
|
@@ -55,9 +56,9 @@ def setup_wandb(train_config, fsdp_config, **kwargs):
|
|
|
)
|
|
|
from llama_recipes.configs import wandb_config as WANDB_CONFIG
|
|
|
wandb_config = WANDB_CONFIG()
|
|
|
- wandb_entity = None if wandb_config.wandb_entity == 'none' else wandb_config.wandb_entity
|
|
|
update_config(wandb_config, **kwargs)
|
|
|
- run = wandb.init(project=wandb_config.wandb_project, entity=wandb_entity)
|
|
|
+ init_dict = dataclasses.asdict(wandb_config)
|
|
|
+ run = wandb.init(**init_dict)
|
|
|
run.config.update(train_config)
|
|
|
run.config.update(fsdp_config, allow_val_change=True)
|
|
|
return run
|
|
@@ -86,7 +87,7 @@ def main(**kwargs):
|
|
|
|
|
|
wandb_run = None
|
|
|
|
|
|
- if train_config.enable_wandb:
|
|
|
+ if train_config.use_wandb:
|
|
|
if not train_config.enable_fsdp or rank==0:
|
|
|
wandb_run = setup_wandb(train_config, fsdp_config, **kwargs)
|
|
|
|
|
@@ -265,7 +266,7 @@ def main(**kwargs):
|
|
|
)
|
|
|
if not train_config.enable_fsdp or rank==0:
|
|
|
[print(f'Key: {k}, Value: {v}') for k, v in results.items()]
|
|
|
- if train_config.enable_wandb:
|
|
|
+ if train_config.use_wandb:
|
|
|
for k,v in results.items():
|
|
|
wandb_run.summary[k] = v
|
|
|
|