################################################################################
# Copyright (c) 2021 ContinualAI. #
# Copyrights licensed under the MIT License. #
# See the accompanying LICENSE file for terms. #
# #
# Date: 19-01-2021 #
# Author(s): Vincenzo Lomonaco, Lorenzo Pellegrini #
# E-mail: contact@continualai.org #
# Website: www.continualai.org #
################################################################################
import os
import time
from typing import Optional, List, TYPE_CHECKING
from threading import Thread
from psutil import Process
from avalanche.evaluation import Metric, PluginMetric, GenericPluginMetric
from avalanche.evaluation.metric_results import MetricResult
if TYPE_CHECKING:
    from avalanche.training.templates import SupervisedTemplate


class MaxRAM(Metric[float]):
    """The standalone RAM usage metric.

    Important: this metric approximates the real maximum RAM usage, since
    it samples the RAM usage only at discrete time intervals.

    Instances of this metric keep track of the maximum RAM usage detected.
    The `start_thread` method starts the usage tracking.
    The `stop_thread` method stops the tracking.
    The result, obtained using the `result` method, is the usage in megabytes.

    The reset method will bring the metric to its initial state. By default,
    this metric in its initial state will return a usage value of 0.
    """

    def __init__(self, every=1):
        """Creates an instance of the RAM usage metric.

        :param every: seconds after which the maximum RAM usage is updated.
        """
        self._process_handle: Optional[Process] = Process(os.getpid())
        """The handle of the monitored process."""

        self.every = every
        """Seconds between two consecutive RAM usage samples."""

        self.stop_f = False
        """Flag used to stop the monitoring thread."""

        self.max_usage = 0
        """
        Main metric result. Max RAM usage (in megabytes).
        """

        self.thread = None
        """
        Thread executing the RAM monitoring code.
        """

    def _f(self):
        """
        Until a stop signal is encountered, this function samples, every
        `every` seconds, the amount of RAM used by the process and keeps
        track of the maximum value observed.
        """
        start_time = time.monotonic()
        while not self.stop_f:
            # RAM usage in MB
            ram_usage = self._process_handle.memory_info().rss / 1024 / 1024
            if ram_usage > self.max_usage:
                self.max_usage = ram_usage
            time.sleep(
                self.every - ((time.monotonic() - start_time) % self.every)
            )

    def result(self) -> Optional[float]:
        """
        Retrieves the maximum RAM usage recorded so far.

        Calling this method will not change the internal state of the metric.

        :return: The maximum RAM usage in megabytes, as a float value.
        """
        return self.max_usage

    def start_thread(self):
        assert not self.thread, (
            "Trying to start thread without joining the previous."
        )
        self.thread = Thread(target=self._f, daemon=True)
        self.thread.start()

    def stop_thread(self):
        if self.thread:
            self.stop_f = True
            self.thread.join()
            self.stop_f = False
            self.thread = None

    def reset(self) -> None:
        """
        Resets the metric.

        :return: None.
        """
        self.max_usage = 0

    def update(self):
        pass
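

# A minimal usage sketch for the standalone `MaxRAM` metric, kept as comments
# so that importing this module has no side effects. The `heavy_workload`
# callable below is a hypothetical placeholder for the code to be monitored:
#
#     ram_metric = MaxRAM(every=0.5)    # sample RAM usage every 0.5 seconds
#     ram_metric.start_thread()         # spawn the background sampling thread
#     heavy_workload()                  # code whose peak RAM we want to track
#     ram_metric.stop_thread()          # join the sampling thread
#     print(f"Peak RAM: {ram_metric.result():.1f} MB")
#     ram_metric.reset()                # back to the initial state (0 MB)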


class RAMPluginMetric(GenericPluginMetric[float]):
    """Base class for the RAM usage plugin metrics.

    It wraps a :class:`MaxRAM` instance and delegates the reset/emit
    behavior to :class:`GenericPluginMetric`.
    """

    def __init__(self, every, reset_at, emit_at, mode):
        self._ram = MaxRAM(every)
        super(RAMPluginMetric, self).__init__(
            self._ram, reset_at, emit_at, mode
        )

    def update(self, strategy):
        self._ram.update()
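

# Sketch (not part of the original module): a different granularity can be
# obtained by subclassing `RAMPluginMetric` with other `reset_at`/`emit_at`/
# `mode` values, mirroring the concrete classes below. The class name here is
# hypothetical:
#
#     class TrainExperienceMaxRAM(RAMPluginMetric):
#         def __init__(self, every=1):
#             super().__init__(
#                 every,
#                 reset_at="experience",
#                 emit_at="experience",
#                 mode="train",
#             )
#
#         def before_training(self, strategy):
#             super().before_training(strategy)
#             self._ram.start_thread()
#
#         def after_training(self, strategy):
#             super().after_training(strategy)
#             self._ram.stop_thread()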


class MinibatchMaxRAM(RAMPluginMetric):
    """The Minibatch Max RAM metric.

    This plugin metric only works at training time.
    """

    def __init__(self, every=1):
        """Creates an instance of the Minibatch Max RAM metric.

        :param every: seconds after which the maximum RAM usage is updated.
        """
        super(MinibatchMaxRAM, self).__init__(
            every, reset_at="iteration", emit_at="iteration", mode="train"
        )

    def before_training(self, strategy: "SupervisedTemplate") -> None:
        super().before_training(strategy)
        self._ram.start_thread()

    def after_training(self, strategy: "SupervisedTemplate") -> None:
        super().after_training(strategy)
        self._ram.stop_thread()

    def __str__(self):
        return "MaxRAMUsage_MB"


class EpochMaxRAM(RAMPluginMetric):
    """The Epoch Max RAM metric.

    This plugin metric only works at training time.
    """

    def __init__(self, every=1):
        """Creates an instance of the Epoch Max RAM metric.

        :param every: seconds after which the maximum RAM usage is updated.
        """
        super(EpochMaxRAM, self).__init__(
            every, reset_at="epoch", emit_at="epoch", mode="train"
        )

    def before_training(self, strategy: "SupervisedTemplate") -> None:
        super().before_training(strategy)
        self._ram.start_thread()

    def after_training(self, strategy: "SupervisedTemplate") -> None:
        super().after_training(strategy)
        self._ram.stop_thread()

    def __str__(self):
        return "MaxRAMUsage_Epoch"


class ExperienceMaxRAM(RAMPluginMetric):
    """The Experience Max RAM metric.

    This plugin metric only works at eval time.
    """

    def __init__(self, every=1):
        """Creates an instance of the Experience Max RAM metric.

        :param every: seconds after which the maximum RAM usage is updated.
        """
        super(ExperienceMaxRAM, self).__init__(
            every, reset_at="experience", emit_at="experience", mode="eval"
        )

    def before_eval(self, strategy: "SupervisedTemplate") -> None:
        super().before_eval(strategy)
        self._ram.start_thread()

    def after_eval(self, strategy: "SupervisedTemplate") -> None:
        super().after_eval(strategy)
        self._ram.stop_thread()

    def __str__(self):
        return "MaxRAMUsage_Experience"


class StreamMaxRAM(RAMPluginMetric):
    """The Stream Max RAM metric.

    This plugin metric only works at eval time.
    """

    def __init__(self, every=1):
        """Creates an instance of the Stream Max RAM metric.

        :param every: seconds after which the maximum RAM usage is updated.
        """
        super(StreamMaxRAM, self).__init__(
            every, reset_at="stream", emit_at="stream", mode="eval"
        )

    def before_eval(self, strategy):
        super().before_eval(strategy)
        self._ram.start_thread()

    def after_eval(self, strategy: "SupervisedTemplate") -> MetricResult:
        packed = super().after_eval(strategy)
        self._ram.stop_thread()
        return packed

    def __str__(self):
        return "MaxRAMUsage_Stream"


def ram_usage_metrics(
    *, every=1, minibatch=False, epoch=False, experience=False, stream=False
) -> List[PluginMetric]:
    """Helper method that can be used to obtain the desired set of
    plugin metrics.

    :param every: seconds after which the maximum RAM usage is updated.
    :param minibatch: If True, will return a metric able to log the minibatch
        max RAM usage.
    :param epoch: If True, will return a metric able to log the epoch
        max RAM usage.
    :param experience: If True, will return a metric able to log the experience
        max RAM usage.
    :param stream: If True, will return a metric able to log the max RAM usage
        of the whole evaluation stream.
    :return: A list of plugin metrics.
    """
    metrics = []
    if minibatch:
        metrics.append(MinibatchMaxRAM(every=every))
    if epoch:
        metrics.append(EpochMaxRAM(every=every))
    if experience:
        metrics.append(ExperienceMaxRAM(every=every))
    if stream:
        metrics.append(StreamMaxRAM(every=every))
    return metrics
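

# A minimal sketch of how these metrics are typically attached to a strategy
# through the evaluation plugin, kept as comments to avoid importing
# `avalanche.training` from this module. The strategy and benchmark objects
# are assumed to exist elsewhere:
#
#     from avalanche.logging import InteractiveLogger
#     from avalanche.training.plugins import EvaluationPlugin
#
#     eval_plugin = EvaluationPlugin(
#         ram_usage_metrics(
#             every=0.5, minibatch=True, epoch=True, experience=True, stream=True
#         ),
#         loggers=[InteractiveLogger()],
#     )
#     # ... then pass `evaluator=eval_plugin` when building the strategy.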


__all__ = [
    "MaxRAM",
    "MinibatchMaxRAM",
    "EpochMaxRAM",
    "ExperienceMaxRAM",
    "StreamMaxRAM",
    "ram_usage_metrics",
]