
################################################################################
# Copyright (c) 2021 ContinualAI.                                              #
# Copyrights licensed under the MIT License.                                   #
# See the accompanying LICENSE file for terms.                                 #
#                                                                              #
# Date: 25-11-2020                                                             #
# Author(s): Diganta Misra, Andrea Cossu, Lorenzo Pellegrini                   #
# E-mail: contact@continualai.org                                              #
# Website: www.continualai.org                                                 #
################################################################################
""" This module handles all the functionalities related to the logging of
Avalanche experiments using Weights & Biases. """

import re
from typing import Optional, Union, List, TYPE_CHECKING
from pathlib import Path
import os
import warnings

import numpy as np
from numpy import array
from torch import Tensor

from PIL.Image import Image
from matplotlib.pyplot import Figure

from avalanche.core import SupervisedPlugin
from avalanche.evaluation.metric_results import (
    AlternativeValues,
    MetricValue,
    TensorImage,
)
from avalanche.logging import BaseLogger

if TYPE_CHECKING:
    from avalanche.evaluation.metric_results import MetricValue
    from avalanche.training.templates import SupervisedTemplate


CHECKPOINT_METRIC_NAME = re.compile(
    r"^WeightCheckpoint\/(?P<phase_name>\S+)_phase\/(?P<stream_name>\S+)_"
    r"stream(\/Task(?P<task_id>\d+))?\/Exp(?P<experience_id>\d+)$"
)
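# A small sketch of how the pattern decomposes a checkpoint metric name (the
# sample string mirrors the example referenced in ``_log_checkpoint`` below):
#
#     m = CHECKPOINT_METRIC_NAME.match(
#         "WeightCheckpoint/train_phase/train_stream/Task000/Exp000"
#     )
#     # m["phase_name"] == "train", m["stream_name"] == "train"
#     # m["task_id"] == "000" (optional group; None when absent)
#     # m["experience_id"] == "000"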


class WandBLogger(BaseLogger, SupervisedPlugin):
    """Weights and Biases logger.

    The `WandBLogger` provides an easy integration with Weights & Biases
    logging. Each monitored metric is automatically logged to a dedicated
    Weights & Biases project dashboard.

    External storage URIs for W&B Artifacts (for instance, AWS S3 and GCS
    buckets) are supported.

    The wandb log files are placed by default in "./wandb/" unless
    otherwise specified.

    .. note::
        TensorBoard can be synced on to the W&B dedicated dashboard.
    """
    def __init__(
        self,
        project_name: str = "Avalanche",
        run_name: str = "Test",
        log_artifacts: bool = False,
        path: Union[str, Path] = "Checkpoints",
        uri: Optional[str] = None,
        sync_tfboard: bool = False,
        save_code: bool = True,
        config: Optional[object] = None,
        dir: Optional[Union[str, Path]] = None,
        params: Optional[dict] = None,
    ):
        """Creates an instance of the `WandBLogger`.

        :param project_name: Name of the W&B project.
        :param run_name: Name of the W&B run.
        :param log_artifacts: Option to log model weights as W&B Artifacts.
            Note that, in order for model weights to be logged, the
            :class:`WeightCheckpoint` metric must be added to the
            evaluation plugin.
        :param path: Path to locally save the model checkpoints.
        :param uri: URI identifier for external storage buckets (GCS, S3).
        :param sync_tfboard: Syncs TensorBoard to the W&B dashboard UI.
        :param save_code: Saves the main training script to W&B.
        :param config: Hyper-parameters and config values to sync to W&B.
        :param dir: Path to the local log directory for W&B logs to be
            saved at.
        :param params: All arguments for the wandb.init() function call.
            Visit https://docs.wandb.ai/ref/python/init to learn about all
            wandb.init() parameters.
        """
        super().__init__()
        self.import_wandb()
        self.project_name = project_name
        self.run_name = run_name
        self.log_artifacts = log_artifacts
        self.path = path
        self.uri = uri
        self.sync_tfboard = sync_tfboard
        self.save_code = save_code
        self.config = config
        self.dir = dir
        self.params = params
        self.args_parse()
        self.before_run()
        self.step = 0
        self.exp_count = 0
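    # A minimal usage sketch (hedged: ``EvaluationPlugin`` and
    # ``accuracy_metrics`` are assumed importable from
    # ``avalanche.training.plugins`` and ``avalanche.evaluation.metrics``,
    # as in current Avalanche releases):
    #
    #     from avalanche.evaluation.metrics import accuracy_metrics
    #     from avalanche.training.plugins import EvaluationPlugin
    #
    #     wandb_logger = WandBLogger(
    #         project_name="my-project", run_name="baseline-run"
    #     )
    #     eval_plugin = EvaluationPlugin(
    #         accuracy_metrics(experience=True, stream=True),
    #         loggers=[wandb_logger],
    #     )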
    def import_wandb(self):
        try:
            import wandb

            assert hasattr(wandb, "__version__")
        except ImportError:
            raise ImportError('Please run "pip install wandb" to install wandb')
        self.wandb = wandb

    def args_parse(self):
        self.init_kwargs = {
            "project": self.project_name,
            "name": self.run_name,
            "sync_tensorboard": self.sync_tfboard,
            "dir": self.dir,
            "save_code": self.save_code,
            "config": self.config,
        }
        if self.params:
            self.init_kwargs.update(self.params)

    def before_run(self):
        if self.wandb is None:
            self.import_wandb()
        if self.init_kwargs is None:
            self.init_kwargs = dict()

        # Reuse an explicit run id if one was given (directly or via the
        # WANDB_RUN_ID environment variable); otherwise generate a fresh one
        # so the run can be resumed later.
        run_id = self.init_kwargs.get("id", None)
        if run_id is None:
            run_id = os.environ.get("WANDB_RUN_ID", None)
        if run_id is None:
            run_id = self.wandb.util.generate_id()
        self.init_kwargs["id"] = run_id

        self.wandb.init(**self.init_kwargs)
        self.wandb.run._label(repo="Avalanche")

    def after_training_exp(
        self,
        strategy: "SupervisedTemplate",
        metric_values: List["MetricValue"],
        **kwargs,
    ):
        for val in metric_values:
            self.log_metrics([val])
        self.wandb.log({"TrainingExperience": self.exp_count}, step=self.step)
        self.exp_count += 1

    def log_single_metric(self, name, value, x_plot):
        self.step = x_plot

        # Weight checkpoints are handled separately, as W&B Artifacts.
        if name.startswith("WeightCheckpoint"):
            if self.log_artifacts:
                self._log_checkpoint(name, value, x_plot)
            return

        if isinstance(value, AlternativeValues):
            value = value.best_supported_value(
                Image,
                Tensor,
                TensorImage,
                Figure,
                float,
                int,
                self.wandb.viz.CustomChart,
            )

        if not isinstance(
            value,
            (
                Image,
                TensorImage,
                Tensor,
                Figure,
                float,
                int,
                self.wandb.viz.CustomChart,
            ),
        ):
            # Unsupported type
            return

        if isinstance(value, Image):
            self.wandb.log({name: self.wandb.Image(value)}, step=self.step)
        elif isinstance(value, Tensor):
            # Tensors are logged as histograms of their flattened values.
            value = np.histogram(value.view(-1).numpy())
            self.wandb.log(
                {name: self.wandb.Histogram(np_histogram=value)}, step=self.step
            )
        elif isinstance(value, (float, int, Figure, self.wandb.viz.CustomChart)):
            self.wandb.log({name: value}, step=self.step)
        elif isinstance(value, TensorImage):
            self.wandb.log({name: self.wandb.Image(array(value))}, step=self.step)

    def _log_checkpoint(self, name, value, x_plot):
        assert self.wandb is not None

        # Example: 'WeightCheckpoint/train_phase/train_stream/Task000/Exp000'
        name_match = CHECKPOINT_METRIC_NAME.match(name)
        if name_match is None:
            warnings.warn(f"Checkpoint metric has unsupported name {name}.")
            return

        # phase_name: str = name_match['phase_name']
        # stream_name: str = name_match['stream_name']
        task_id: Optional[int] = (
            int(name_match["task_id"]) if name_match["task_id"] is not None else None
        )
        experience_id: int = int(name_match["experience_id"])
        assert experience_id >= 0

        cwd = Path.cwd()
        checkpoint_directory = cwd / self.path
        checkpoint_directory.mkdir(parents=True, exist_ok=True)

        checkpoint_name = "Model_{}".format(experience_id)
        checkpoint_file_name = checkpoint_name + ".pth"
        checkpoint_path = checkpoint_directory / checkpoint_file_name
        artifact_name = "Models/" + checkpoint_file_name

        # Write the checkpoint blob
        with open(checkpoint_path, "wb") as f:
            f.write(value)

        metadata = {
            "experience": experience_id,
            "x_step": x_plot,
            **({"task_id": task_id} if task_id is not None else {}),
        }

        artifact = self.wandb.Artifact(
            checkpoint_name, type="model", metadata=metadata
        )
        artifact.add_file(str(checkpoint_path), name=artifact_name)
        self.wandb.run.log_artifact(artifact)
        if self.uri is not None:
            artifact.add_reference(self.uri, name=artifact_name)

    def __getstate__(self):
        state = self.__dict__.copy()
        if "wandb" in state:
            del state["wandb"]
        return state

    def __setstate__(self, state):
        print("[W&B logger] Resuming from checkpoint...")
        self.__dict__ = state
        if self.init_kwargs is None:
            self.init_kwargs = dict()
        self.init_kwargs["resume"] = "allow"
        self.wandb = None
        self.before_run()
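# Note on pickling: ``__getstate__`` drops the unpicklable ``wandb`` module
# reference before serialization, while ``__setstate__`` re-runs
# ``before_run()`` with ``resume="allow"`` so the logger reattaches to the
# same W&B run (the run id is kept in ``init_kwargs``). A rough sketch of
# what this enables:
#
#     import pickle
#
#     blob = pickle.dumps(wandb_logger)  # e.g., inside a training checkpoint
#     restored = pickle.loads(blob)      # resumes logging to the same run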
__all__ = ["WandBLogger"]
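# To upload model weights as W&B Artifacts, ``log_artifacts=True`` must be
# paired with the ``WeightCheckpoint`` metric (see the class docstring). A
# rough sketch, assuming ``WeightCheckpoint`` is importable from
# ``avalanche.evaluation.metrics``:
#
#     from avalanche.evaluation.metrics import WeightCheckpoint
#     from avalanche.training.plugins import EvaluationPlugin
#
#     logger = WandBLogger(project_name="Avalanche", log_artifacts=True)
#     eval_plugin = EvaluationPlugin(WeightCheckpoint(), loggers=[logger])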