|
@@ -30,8 +30,9 @@ from llama_recipes.utils.tflop_counter import FlopCounterMode
|
|
|
@contextlib.contextmanager
|
|
|
def maybe_run_profiler(cfg, *args, **kwargs):
|
|
|
use_profiler: bool = cfg.profiler
|
|
|
-
|
|
|
+
|
|
|
if use_profiler:
|
|
|
+ print(f"profiling is activated and results will be saved in {cfg.profile_output_dir}")
|
|
|
with torch.profiler.profile(
|
|
|
activities=[
|
|
|
torch.profiler.ProfilerActivity.CPU,
|
|
@@ -594,4 +595,4 @@ def save_to_json(output_filename, train_step_loss, train_epoch_loss, train_step_
|
|
|
"val_epoch_perplexity": val_epoch_ppl
|
|
|
}
|
|
|
with open(output_filename, "w") as f:
|
|
|
- json.dump(metrics_data, f)
|
|
|
+ json.dump(metrics_data, f)
|