Skip to content

wandb

optimus_dl.modules.loggers.wandb

Weights & Biases (wandb) metrics logger implementation.

This logger integrates with Weights & Biases for experiment tracking, supporting both online and offline modes.

WandbLogger

Bases: BaseMetricsLogger

Weights & Biases metrics logger.

Logs training metrics and the experiment configuration to Weights & Biases for experiment tracking and visualization, and optionally uploads the run's log directory as a `logs` artifact when the run is closed.

Supports resuming runs by storing and reloading the WandB run_id from the training state dict.

Source code in optimus_dl/modules/loggers/wandb.py
@register_metrics_logger("wandb", WandbLoggerConfig)
class WandbLogger(BaseMetricsLogger):
    """Weights & Biases metrics logger.

    Logs training metrics and the experiment configuration to Weights &
    Biases for experiment tracking and visualization, and optionally uploads
    the run's log directory as a ``logs`` artifact when the run is closed.

    Supports resuming runs by storing and reloading the WandB `run_id` from
    the training state dict.
    """

    def __init__(self, cfg: WandbLoggerConfig, state_dict=None, **kwargs):
        """Initialize WandB logger.

        Args:
            cfg: WandB logger configuration.
            state_dict: Optional state containing 'run_id' for resuming.
            **kwargs: Additional keyword arguments forwarded to the base class.
        """
        super().__init__(cfg, **kwargs)

        # Set every attribute before the early returns below: close() and
        # state_dict() read them unconditionally, so a disabled logger would
        # otherwise raise AttributeError. Keeping run_id also means a session
        # with WandB disabled does not discard the id needed to resume later.
        self.run_id = (state_dict or {}).get("run_id")
        self.run = None
        self.logs_parent_path = None

        if not WANDB_AVAILABLE:
            self.enabled = False
            logger.error("WandB logger disabled - wandb package not available")
            return

        if cfg.mode == "disabled":
            self.enabled = False
            logger.info("WandB logger disabled via mode setting")
            return

    def setup(
        self,
        experiment_name: str,
        config: dict[str, Any],
        logs_parent_path: str | None = None,
        start_iteration: int | None = None,
    ) -> None:
        """Initialize a WandB run with experiment metadata and configuration.

        If `self.run_id` is set and the mode is 'online', the stored run is
        looked up first; if its name differs from this run's name a fresh run
        is started instead of resuming. Any failure while initializing the
        run disables this logger.

        Args:
            experiment_name: Run name used when `cfg.name` is not set.
            config: Experiment configuration; OmegaConf configs are resolved
                to plain containers before being sent to WandB.
            logs_parent_path: Optional path that `close()` uploads as a
                "logs" artifact.
            start_iteration: Not used by this implementation.
        """
        if not self.enabled:
            return

        import wandb

        name = self.cfg.name or experiment_name

        try:
            if self.run_id is not None and self.cfg.mode == "online":
                self._check_resume_target(wandb, name)

            # Initialize wandb run
            if OmegaConf.is_config(config):
                config = OmegaConf.to_container(config, resolve=True)

            self.run = wandb.init(
                project=self.cfg.project,
                entity=self.cfg.entity,
                mode=self.cfg.mode,
                name=name,
                group=self.cfg.group,
                job_type=self.cfg.job_type,
                tags=list(self.cfg.tags) if self.cfg.tags else None,
                notes=self.cfg.notes,
                save_code=self.cfg.save_code,
                config=config,
                id=self.run_id,
                resume="allow",
            )
            self.logs_parent_path = logs_parent_path
            # Remember the id actually in use (it may be a brand-new run) so
            # state_dict() can still report it after the run is closed.
            self.run_id = self.run.id

            # Configure "iteration" as the global step metric so it is used as the x-axis.
            try:
                wandb.define_metric("iteration")
                wandb.define_metric("*", step_metric="iteration")
            except Exception:
                # Older wandb versions or unexpected errors: fall back without breaking logging.
                logger.debug(
                    "Failed to define WandB step metric 'iteration'", exc_info=True
                )

            logger.info(f"WandB run initialized: {self.run.name} ({self.run.id})")
        except Exception as e:
            logger.error(f"Failed to initialize WandB: {e}", exc_info=True)
            self.enabled = False

    def _check_resume_target(self, wandb, name: str) -> None:
        """Check that the stored run matches `name`; clear `self.run_id` if it does not.

        This lookup is only a sanity check. Any error (run not found, network
        or authentication failure) is logged and the stored id is kept, so
        `wandb.init(resume="allow")` decides whether to resume or create.

        Args:
            wandb: The imported ``wandb`` module.
            name: Display name of the run about to be initialized.
        """
        entity = self.cfg.entity or os.getenv("WANDB_ENTITY")
        project = self.cfg.project or os.getenv("WANDB_PROJECT")

        # Build the most specific run path the available settings allow.
        if project is not None and entity is not None:
            run_path = f"{entity}/{project}/{self.run_id}"
        elif project is not None:
            run_path = f"{project}/{self.run_id}"
        else:
            run_path = f"{self.run_id}"

        try:
            existing_run = wandb.Api().run(run_path)
        except Exception as e:
            logger.warning(
                f"Could not load wandb run {run_path}: {e}. Launching a new wandb run."
            )
            return

        if existing_run.name != name:
            logger.warning(
                f"Wandb run name does not match the loaded experiment name: {name} (this run) != {existing_run.name} ({self.run_id}). Launching a new wandb run."
            )
            self.run_id = None

    def log_metrics(
        self, metrics: dict[str, Any], step: int, group: str = "train"
    ) -> None:
        """Flatten and log metrics to WandB.

        Top-level keys are prefixed with ``group``. Dict-valued metrics are
        expanded one level into ``group/key/sub_key``. ``step`` is attached
        under both ``"step"`` and ``"iteration"``. Errors are logged, not
        raised.

        Args:
            metrics: Dictionary of metric names to values.
            step: Training step/iteration number.
            group: Metrics group (e.g., 'train', 'eval').
        """
        if not self.enabled:
            return

        if self.run is None:
            logger.warning("WandB run not initialized, skipping metrics logging")
            return

        try:
            flattened_metrics = {}
            for key, value in metrics.items():
                if isinstance(value, dict):
                    # Handle nested metrics (one level deep).
                    for nested_key, nested_value in value.items():
                        flattened_metrics[f"{group}/{key}/{nested_key}"] = nested_value
                else:
                    flattened_metrics[f"{group}/{key}"] = value

            flattened_metrics["step"] = step
            flattened_metrics["iteration"] = step
            self.run.log(flattened_metrics)
        except Exception as e:
            logger.error(f"Failed to log metrics to WandB: {e}")

    def close(self) -> None:
        """Upload the log directory (if configured) and finish the WandB run.

        Both steps are best-effort: failures are logged, not raised. The run
        handle is always cleared afterwards.
        """
        if self.run is None:
            return

        try:
            if self.logs_parent_path is not None:
                self.run.log_artifact(
                    artifact_or_path=self.logs_parent_path, type="logs"
                )
        except Exception as e:
            logger.error(f"Error saving logs as artifacts: {e}")
        try:
            self.run.finish()
            logger.info("WandB run finished successfully")
        except Exception as e:
            logger.error(f"Error finishing WandB run: {e}")
        finally:
            self.run = None

    def state_dict(self):
        """Return the WandB run id needed to resume this run later.

        Uses the active run's id when a run is open. Otherwise falls back to
        the id loaded at construction or recorded during setup, so a disabled
        or already-closed logger does not discard it.
        """
        run_id = self.run.id if self.run is not None else self.run_id
        return {"run_id": run_id}

    def finished(self, status: RunStatus):
        """Record the final run status in the WandB run summary.

        Writes ``status.value`` under the ``"finish_status"`` summary key.
        Does nothing when the logger is disabled or no run is open.
        """
        if not self.enabled or self.run is None:
            return

        try:
            self.run.summary["finish_status"] = status.value
            logger.info(f"WandB run finished with status: {status.value}")
        except Exception as e:
            logger.error(f"Failed to set run status tag in WandB: {e}")

__init__(cfg, state_dict=None, **kwargs)

Initialize WandB logger.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `cfg` | `WandbLoggerConfig` | WandB logger configuration. | required |
| `state_dict` | | Optional state containing `'run_id'` for resuming. | `None` |
| `**kwargs` | | Additional keyword arguments. | `{}` |
Source code in optimus_dl/modules/loggers/wandb.py
def __init__(self, cfg: WandbLoggerConfig, state_dict=None, **kwargs):
    """Initialize WandB logger.

    Args:
        cfg: WandB logger configuration.
        state_dict: Optional state containing 'run_id' for resuming.
        **kwargs: Additional keyword arguments forwarded to the base class.
    """
    super().__init__(cfg, **kwargs)

    # Set every attribute before the early returns below: close() and
    # state_dict() read them unconditionally, so a disabled logger would
    # otherwise raise AttributeError.
    self.run_id = (state_dict or {}).get("run_id")
    self.run = None
    self.logs_parent_path = None

    if not WANDB_AVAILABLE:
        self.enabled = False
        logger.error("WandB logger disabled - wandb package not available")
        return

    if cfg.mode == "disabled":
        self.enabled = False
        logger.info("WandB logger disabled via mode setting")
        return

close()

Finalize and close the WandB run.

Source code in optimus_dl/modules/loggers/wandb.py
def close(self) -> None:
    """Upload the log directory (if configured) and finish the WandB run.

    Both steps are best-effort: any failure is logged rather than raised,
    and the run handle is cleared afterwards even when finishing fails.
    """
    active_run = self.run
    if active_run is None:
        return

    try:
        if self.logs_parent_path is not None:
            active_run.log_artifact(
                artifact_or_path=self.logs_parent_path, type="logs"
            )
    except Exception as e:
        logger.error(f"Error saving logs as artifacts: {e}")

    try:
        active_run.finish()
        logger.info("WandB run finished successfully")
    except Exception as e:
        logger.error(f"Error finishing WandB run: {e}")
    finally:
        self.run = None

finished(status)

Record the final run status in the WandB run summary (under the `finish_status` key) at the end of the run.

Source code in optimus_dl/modules/loggers/wandb.py
def finished(self, status: RunStatus):
    """Record the final run status in the WandB run summary.

    Stores ``status.value`` under the ``"finish_status"`` summary key. This
    is a no-op when the logger is disabled or no run is open. Errors are
    logged rather than raised.
    """
    if self.enabled and self.run is not None:
        try:
            self.run.summary["finish_status"] = status.value
            logger.info(f"WandB run finished with status: {status.value}")
        except Exception as e:
            logger.error(f"Failed to set run status tag in WandB: {e}")

log_metrics(metrics, step, group='train')

Flatten and log metrics to WandB.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `metrics` | `dict[str, Any]` | Dictionary of metric names to values. | required |
| `step` | `int` | Training step/iteration number. | required |
| `group` | `str` | Metrics group (e.g., `'train'`, `'eval'`). | `'train'` |
Source code in optimus_dl/modules/loggers/wandb.py
def log_metrics(
    self, metrics: dict[str, Any], step: int, group: str = "train"
) -> None:
    """Send one batch of metrics to the active WandB run.

    Every top-level metric name is prefixed with ``group``. Dict-valued
    metrics are expanded one level, so their entries become
    ``group/key/sub_key``. The ``step`` value is attached under both
    ``"step"`` and ``"iteration"``. Errors are logged rather than raised.

    Args:
        metrics: Dictionary of metric names to values.
        step: Training step/iteration number.
        group: Metrics group (e.g., 'train', 'eval').
    """
    if not self.enabled:
        return
    if self.run is None:
        logger.warning("WandB run not initialized, skipping metrics logging")
        return

    try:
        payload: dict[str, Any] = {}
        for metric_name, metric_value in metrics.items():
            prefix = f"{group}/{metric_name}"
            if isinstance(metric_value, dict):
                payload.update(
                    {
                        f"{prefix}/{sub_name}": sub_value
                        for sub_name, sub_value in metric_value.items()
                    }
                )
            else:
                payload[prefix] = metric_value

        payload["step"] = step
        payload["iteration"] = step
        self.run.log(payload)
    except Exception as e:
        logger.error(f"Failed to log metrics to WandB: {e}")

setup(experiment_name, config, logs_parent_path=None, start_iteration=None)

Initialize a WandB run with experiment metadata and configuration.

If self.run_id is present, attempts to resume the existing run.

Source code in optimus_dl/modules/loggers/wandb.py
def setup(
    self,
    experiment_name: str,
    config: dict[str, Any],
    logs_parent_path: str | None = None,
    start_iteration: int | None = None,
) -> None:
    """Initialize a WandB run with experiment metadata and configuration.

    If `self.run_id` is set and the mode is 'online', the stored run is
    looked up first; if its name differs from this run's name a fresh run is
    started instead of resuming. The lookup is best-effort: failures are
    logged and `wandb.init(resume="allow")` decides what happens. Any failure
    while initializing the run itself disables this logger.

    Args:
        experiment_name: Run name used when `cfg.name` is not set.
        config: Experiment configuration; OmegaConf configs are resolved to
            plain containers before being sent to WandB.
        logs_parent_path: Optional path that `close()` uploads as a "logs"
            artifact.
        start_iteration: Not used by this implementation.
    """
    if not self.enabled:
        return

    import wandb

    name = self.cfg.name or experiment_name

    try:
        if self.run_id is not None and self.cfg.mode == "online":
            entity = self.cfg.entity or os.getenv("WANDB_ENTITY")
            project = self.cfg.project or os.getenv("WANDB_PROJECT")

            run_path = f"{self.run_id}"
            if project is not None and entity is not None:
                run_path = f"{entity}/{project}/{self.run_id}"
            elif project is not None:
                run_path = f"{project}/{self.run_id}"

            # The lookup is only a sanity check. Network, auth or not-found
            # errors must not fall through to the outer handler, which would
            # disable logging for the whole run.
            try:
                existing_run = wandb.Api().run(run_path)
            except Exception as e:
                logger.warning(
                    f"Could not load wandb run {run_path}: {e}. Launching a new wandb run."
                )
            else:
                if existing_run.name != name:
                    logger.warning(
                        f"Wandb run name does not match the loaded experiment name: {name} (this run) != {existing_run.name} ({self.run_id}). Launching a new wandb run."
                    )
                    self.run_id = None

        # Initialize wandb run
        if OmegaConf.is_config(config):
            config = OmegaConf.to_container(config, resolve=True)

        self.run = wandb.init(
            project=self.cfg.project,
            entity=self.cfg.entity,
            mode=self.cfg.mode,
            name=name,
            group=self.cfg.group,
            job_type=self.cfg.job_type,
            tags=list(self.cfg.tags) if self.cfg.tags else None,
            notes=self.cfg.notes,
            save_code=self.cfg.save_code,
            config=config,
            id=self.run_id,
            resume="allow",
        )
        self.logs_parent_path = logs_parent_path

        # Configure "iteration" as the global step metric so it is used as the x-axis.
        try:
            wandb.define_metric("iteration")
            wandb.define_metric("*", step_metric="iteration")
        except Exception:
            # Older wandb versions or unexpected errors: fall back without breaking logging.
            logger.debug(
                "Failed to define WandB step metric 'iteration'", exc_info=True
            )

        logger.info(f"WandB run initialized: {self.run.name} ({self.run.id})")
    except Exception as e:
        logger.error(f"Failed to initialize WandB: {e}", exc_info=True)
        self.enabled = False

state_dict()

Return the current run ID for resuming later.

Source code in optimus_dl/modules/loggers/wandb.py
def state_dict(self):
    """Snapshot the information needed to resume this WandB run.

    Returns:
        A dict with a single ``"run_id"`` key: the id of the currently open
        run, or ``None`` when no run is open.
    """
    active_run = self.run
    return {"run_id": None if active_run is None else active_run.id}

WandbLoggerConfig dataclass

Bases: MetricsLoggerConfig

Configuration for Weights & Biases logger.

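Attributes / Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `project` | `str \| None` | Name of the WandB project. | `None` |
| `entity` | `str \| None` | WandB entity (user or team) to log to. | `None` |
| `mode` | `str` | WandB run mode: `'online'`, `'offline'`, or `'disabled'`. | `'online'` |
| `save_code` | `bool` | If True, asks WandB to save the main script with the run. | `True` |
| `group` | `str \| None` | Name for grouping related runs. | `None` |
| `job_type` | `str \| None` | Label for the type of run (e.g., `'train'`, `'eval'`). | `'train'` |
| `name` | `str \| None` | Display name for the run. If None, uses the experiment name. | `None` |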
Source code in optimus_dl/modules/loggers/wandb.py
@dataclass
class WandbLoggerConfig(MetricsLoggerConfig):
    """Configuration for the Weights & Biases metrics logger.

    Fields inherited from ``MetricsLoggerConfig`` are not listed here. The
    logger also reads ``tags`` and ``notes`` from its config; presumably
    those are defined on the base class (confirm).

    Attributes:
        project: Name of the WandB project. When resuming an online run, the
            run lookup falls back to the ``WANDB_PROJECT`` environment
            variable if this is unset.
        entity: WandB entity (user or team) to log to. When resuming an
            online run, the run lookup falls back to the ``WANDB_ENTITY``
            environment variable if this is unset.
        mode: WandB run mode: 'online', 'offline', or 'disabled'. 'disabled'
            turns the logger off entirely. Existing runs are only looked up
            for resumption in 'online' mode.
        save_code: Forwarded to ``wandb.init(save_code=...)``; if True, WandB
            saves the main script with the run.
        group: Name for grouping related runs.
        job_type: Label for the type of run (e.g., 'train', 'eval').
        name: Display name for the run. If None, uses the experiment name.
    """

    # WandB connection / project settings
    project: str | None = None
    entity: str | None = None
    mode: str = "online"  # one of "online", "offline", or "disabled"
    save_code: bool = True

    # Run metadata passed through to wandb.init
    group: str | None = None
    job_type: str | None = "train"
    name: str | None = None  # None -> use the experiment name