################################################################################
# Copyright (c) 2021 ContinualAI. #
# Copyrights licensed under the MIT License. #
# See the accompanying LICENSE file for terms. #
# #
# Date: 30-12-2020 #
# Author(s): Diganta Misra #
# E-mail: contact@continualai.org #
# Website: www.continualai.org #
################################################################################
import copy
from typing import TYPE_CHECKING, List, Optional

from torch import Tensor

from avalanche.evaluation import PluginMetric
from avalanche.evaluation.metric_results import MetricValue, MetricResult
from avalanche.evaluation.metric_utils import get_metric_name

if TYPE_CHECKING:
    from avalanche.training.templates import SupervisedTemplate
class WeightCheckpoint(PluginMetric[List[Tensor]]):
    """
    The WeightCheckpoint Metric.

    Instances of this metric keep a snapshot of the model's parameters,
    taken at the end of each training experience.

    Each time `result` is called, this metric returns the snapshot taken
    after the most recent training experience since the last `reset`.

    The reset method will bring the metric to its initial state. By default
    this metric in its initial state will return None.
    """

    def __init__(self):
        """
        Creates an instance of the WeightCheckpoint Metric.

        By default this metric in its initial state will return None.
        The metric can be updated by using the `update` method
        while the current experience's weight checkpoint can be
        retrieved using the `result` method.
        """
        super().__init__()
        # Latest snapshot passed to `update`; None until the first update
        # and again after every `reset`.
        self.weights: Optional[List[Tensor]] = None

    def update(self, weights: List[Tensor]) -> None:
        """
        Update the weight checkpoint at the current experience.

        :param weights: the weight snapshot at the current experience.
        :return: None.
        """
        self.weights = weights

    def result(self) -> Optional[List[Tensor]]:
        """
        Retrieves the weight checkpoint at the current experience.

        :return: The stored weight snapshot, or None if nothing has been
            stored since construction or the last `reset`.
        """
        return self.weights

    def reset(self) -> None:
        """
        Resets the metric.

        :return: None.
        """
        self.weights = None

    def _package_result(self, strategy) -> "MetricResult":
        """
        Wrap the current snapshot into a one-element list of MetricValue.

        The metric name includes the experience (but not the task), and the
        value is stamped with the strategy's training-iteration counter.

        :param strategy: the strategy whose clock and experience are used
            to build the metric name and x-position.
        :return: a list containing a single MetricValue.
        """
        weights = self.result()
        metric_name = get_metric_name(
            self, strategy, add_experience=True, add_task=False
        )
        return [
            MetricValue(
                self, metric_name, weights, strategy.clock.train_iterations
            )
        ]

    def after_training_exp(
        self, strategy: "SupervisedTemplate"
    ) -> "MetricResult":
        """
        Snapshot the model's parameters at the end of a training experience.

        :param strategy: the running strategy; its `model` is snapshotted.
        :return: None (the snapshot is stored, not emitted).
        """
        # `parameters()` returns a generator, which `copy.deepcopy` cannot
        # copy. Materialize an independent copy of every parameter instead:
        # `detach()` drops the autograd link and `clone()` allocates new
        # storage, so later optimizer steps do not alter the snapshot.
        model_params = [
            param.detach().clone() for param in strategy.model.parameters()
        ]
        self.update(model_params)
        # NOTE(review): `_package_result` is never invoked, so the snapshot
        # is not emitted to loggers. Returning
        # `self._package_result(strategy)` would emit it; confirm which
        # behavior is intended.
        return None

    def __str__(self):
        return "WeightCheckpoint"
# Public names exported by `from <this module> import *`.
__all__ = [
    "WeightCheckpoint",
]