import lightning as L
import time
import datetime
from typing import Dict, Union
class EpochProgressBar(L.pytorch.callbacks.ProgressBar):
    """Plain-text progress reporter that prints once per training epoch.

    Instead of drawing a live per-batch bar, this callback prints, at the end
    of every training epoch, the epoch index and how long the epoch took,
    followed by the trainer's current progress-bar metrics split into a
    training line (keys prefixed ``train_``) and a validation line (keys
    prefixed ``val_``). When training finishes it prints the total training
    time. Every hook is wrapped in ``rank_zero_only`` so output is printed by
    a single process.

    Args:
        precision: Number of decimal places used when printing float metric
            values.
        *args, **kwargs: Forwarded unchanged to the base ``ProgressBar``.
    """

    def __init__(self, precision=4, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.precision = precision

    @L.pytorch.utilities.rank_zero_only
    def on_train_start(
        self, trainer: "L.Trainer", pl_module: "L.LightningModule"
    ) -> None:
        """Record the wall-clock start time of the whole training run."""
        super().on_train_start(trainer, pl_module)
        self.training_start = time.time()

    @L.pytorch.utilities.rank_zero_only
    def on_train_epoch_start(
        self, trainer: "L.Trainer", pl_module: "L.LightningModule"
    ) -> None:
        """Record the wall-clock start time of the current epoch."""
        self.epoch_start = time.time()

    @L.pytorch.utilities.rank_zero_only
    def on_train_epoch_end(
        self, trainer: "L.Trainer", pl_module: "L.LightningModule"
    ) -> None:
        """Print the epoch index, the epoch's duration and current metrics."""
        # Round to whole seconds so the timedelta prints without a
        # fractional-seconds component.
        duration = datetime.timedelta(seconds=round(time.time() - self.epoch_start))
        print(f"Epoch: {trainer.current_epoch}, Time elapsed: {duration}")
        print(self._format_metrics(trainer.progress_bar_metrics))

    @L.pytorch.utilities.rank_zero_only
    def on_train_end(
        self, trainer: "L.Trainer", pl_module: "L.LightningModule"
    ) -> None:
        """Print the total wall-clock time of the training run."""
        duration = str(
            datetime.timedelta(seconds=round(time.time() - self.training_start))
        )
        print(f"Training finished. Time elapsed: {duration}")

    def disable(self) -> None:
        """No-op: there is no live bar to switch off.

        NOTE(review): presumably overridden because the base ``ProgressBar``
        requires subclasses to implement ``disable``/``enable`` — confirm.
        """
        return

    def enable(self) -> None:
        """No-op counterpart of :meth:`disable`."""
        return

    def _format_metrics(self, metrics: Dict[str, Union[float, Dict]]) -> str:
        """Render metrics as a two-line training / validation summary.

        Keys starting with ``train_`` are listed on the ``TRAINING METRICS``
        line and keys starting with ``val_`` on the ``VALIDATION METRICS``
        line, each with that prefix removed and entries separated by
        ``", "``. Keys with neither prefix are skipped. Float values are
        printed with ``self.precision`` decimal places; any other value is
        printed with its default f-string formatting.

        Args:
            metrics: Mapping of metric name to value, e.g. the trainer's
                ``progress_bar_metrics``.

        Returns:
            The two summary lines joined by a newline. A group with no
            matching keys yields an empty line body.
        """
        train_prefix = "train_"
        val_prefix = "val_"
        train_parts = []
        val_parts = []
        for key, value in metrics.items():
            if isinstance(value, float):
                value = f"{value:.{self.precision}f}"
            if key.startswith(train_prefix):
                train_parts.append(f"{key[len(train_prefix):]}: {value}")
            elif key.startswith(val_prefix):
                val_parts.append(f"{key[len(val_prefix):]}: {value}")
        # join() handles the separator, so no trailing ", " needs trimming.
        return (
            f"TRAINING METRICS: {', '.join(train_parts)}\n"
            f"VALIDATION METRICS: {', '.join(val_parts)}"
        )