# Source code for qumphy.callbacks.progressbar

import lightning as L
import time
import datetime
from typing import Dict, Union


class EpochProgressBar(L.pytorch.callbacks.ProgressBar):
    """Progress "bar" that prints a plain-text summary once per training epoch.

    Instead of drawing a live bar, this callback prints the elapsed time of
    each training epoch, followed by the current progress-bar metrics. The
    metrics are split into a ``train_*`` line and a ``val_*`` line. The total
    training time is printed when training ends. The hooks are wrapped in
    ``rank_zero_only``, so output is produced only on the rank-zero process.

    Args:
        precision: Number of decimal places used when printing float metrics.
        *args: Forwarded to the Lightning ``ProgressBar`` base class.
        **kwargs: Forwarded to the Lightning ``ProgressBar`` base class.
    """

    def __init__(self, precision: int = 4, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.precision = precision
        # Set by disable()/enable(). Previously those methods were no-ops, so
        # callers asking the progress bar to be silenced were ignored.
        self._printing_enabled = True

    @property
    def is_enabled(self) -> bool:
        """Whether this callback currently prints its summaries."""
        return self._printing_enabled

    def disable(self) -> None:
        """Suppress all printed output until :meth:`enable` is called."""
        self._printing_enabled = False

    def enable(self) -> None:
        """Resume printing after a previous call to :meth:`disable`."""
        self._printing_enabled = True

    @staticmethod
    def _elapsed_since(start: float) -> datetime.timedelta:
        """Return the time since ``start`` (a ``time.monotonic()`` value), rounded to whole seconds."""
        return datetime.timedelta(seconds=round(time.monotonic() - start))

    @L.pytorch.utilities.rank_zero_only
    def on_train_start(
        self, trainer: "L.Trainer", pl_module: "L.LightningModule"
    ) -> None:
        """Record the start time of the whole training run."""
        super().on_train_start(trainer, pl_module)
        # Monotonic clock: durations must not jump if the wall clock is
        # adjusted (e.g. by NTP) while training is running.
        self.training_start = time.monotonic()

    @L.pytorch.utilities.rank_zero_only
    def on_train_epoch_start(
        self, trainer: "L.Trainer", pl_module: "L.LightningModule"
    ) -> None:
        """Record the start time of the current training epoch."""
        self.epoch_start = time.monotonic()

    @L.pytorch.utilities.rank_zero_only
    def on_train_epoch_end(
        self, trainer: "L.Trainer", pl_module: "L.LightningModule"
    ) -> None:
        """Print the epoch index, its duration and the formatted metrics."""
        if not self.is_enabled:
            return
        duration = self._elapsed_since(self.epoch_start)
        print(f"Epoch: {trainer.current_epoch}, Time elapsed: {duration}")
        print(self._format_metrics(trainer.progress_bar_metrics))

    @L.pytorch.utilities.rank_zero_only
    def on_train_end(
        self, trainer: "L.Trainer", pl_module: "L.LightningModule"
    ) -> None:
        """Print the total duration of the training run."""
        if not self.is_enabled:
            return
        duration = str(self._elapsed_since(self.training_start))
        print(f"Training finished. Time elapsed: {duration}")

    def _format_metrics(self, metrics: Dict[str, Union[float, Dict]]) -> str:
        """Render ``train_*`` and ``val_*`` metrics as two labelled lines.

        Args:
            metrics: Mapping from metric name to value. Float values are
                printed with ``self.precision`` decimals. Any other value is
                printed via its ``str()`` form.

        Returns:
            ``"TRAINING METRICS: a: 1.0000, b: 2.0000\\nVALIDATION METRICS: ..."``.
            Keys are shown with their ``train_``/``val_`` prefix removed. Keys
            with neither prefix are omitted. A group with no entries produces
            an empty list after its label.
        """
        train_prefix = "train_"
        val_prefix = "val_"
        train_entries = []
        val_entries = []
        for key, value in metrics.items():
            if isinstance(value, float):
                value = f"{value:.{self.precision}f}"
            if key.startswith(train_prefix):
                name = key[len(train_prefix):]
                train_entries.append(f"{name}: {value}")
            elif key.startswith(val_prefix):
                name = key[len(val_prefix):]
                val_entries.append(f"{name}: {value}")
        train_text = ", ".join(train_entries)
        val_text = ", ".join(val_entries)
        return (
            f"TRAINING METRICS: {train_text}\n"
            f"VALIDATION METRICS: {val_text}"
        )