8 changes: 3 additions & 5 deletions torchrec/metrics/cpu_comms_metric_module.py
@@ -106,9 +106,6 @@ def _load_metric_states(
Uses aggregated states.
"""

# All update() calls were done prior. Clear previous computed state.
# Otherwise, we get warnings that compute() was called before
# update() which is not the case.
computation = cast(RecMetricComputation, computation)
set_update_called(computation)
computation._computed = None
@@ -157,8 +154,9 @@ def _clone_rec_metrics(self) -> RecMetricList:

def set_update_called(computation: RecMetricComputation) -> None:
"""
Set _update_called to True for RecMetricComputation.
This is a workaround for torchmetrics 1.0.3+.
All update() calls were done prior. Clear the previously computed state;
otherwise, we get warnings that compute() was called before
update(), which is not the case.
"""
try:
computation._update_called = True
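For readers of this hunk, a minimal sketch of the pattern the relocated comment documents (not the PR's exact code: `load_aggregated_states` is a hypothetical stand-in for `_load_metric_states`, and the `AttributeError` fallback is an assumption about torchmetrics 1.0.3+, where `_update_called` became a read-only property derived from `_update_count`):

```python
# Hedged sketch, not the module's real implementation. All update() calls
# happen earlier on the workers; before loading aggregated states we mark
# update() as called and drop any cached compute() result so torchmetrics
# neither warns about compute-before-update nor serves a stale value.
from typing import Any, Dict

from torchmetrics import Metric


def set_update_called(computation: Metric) -> None:
    try:
        computation._update_called = True  # plain attribute on older torchmetrics
    except AttributeError:
        # Assumed fallback for newer releases where `_update_called` is a
        # property computed from `_update_count`.
        computation._update_count = max(computation._update_count, 1)


def load_aggregated_states(computation: Metric, states: Dict[str, Any]) -> None:
    set_update_called(computation)  # suppress "compute() before update()" warning
    computation._computed = None    # clear the previously computed (cached) result
    for name, value in states.items():
        setattr(computation, name, value)  # install the aggregated state
```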
96 changes: 57 additions & 39 deletions torchrec/metrics/cpu_offloaded_metric_module.py
@@ -22,7 +22,7 @@
MetricUpdateJob,
SynchronizationMarker,
)
from torchrec.metrics.metric_module import MetricValue, RecMetricModule
from torchrec.metrics.metric_module import MetricsFuture, MetricsResult, RecMetricModule
from torchrec.metrics.metric_state_snapshot import MetricStateSnapshot
from torchrec.metrics.model_utils import parse_task_model_outputs
from torchrec.metrics.rec_metric import RecMetricException
@@ -62,19 +62,27 @@ class CPUOffloadedRecMetricModule(RecMetricModule):

def __init__(
self,
device: torch.device,
update_queue_size: int = 100,
compute_queue_size: int = 100,
*args: Any,
**kwargs: Any,
) -> None:
"""
Args:
All arguments are the same as RecMetricModule except for
- update_queue_size: Maximum size of the update queue. Default is 100.
- compute_queue_size: Maximum size of the update queue. Default is 100.
batch_size: batch size used by this trainer.
world_size: the number of trainers.
device: the device where the model is located (used to determine whether to perform GPU to CPU transfers).
update_queue_size: Maximum size of the update queue. Default is 100.
compute_queue_size: Maximum size of the compute queue. Default is 100.
*args: Additional positional arguments passed to RecMetricModule.
**kwargs: Additional keyword arguments passed to RecMetricModule.
"""
super().__init__(*args, **kwargs)
self._shutdown_event = threading.Event()
self._device = device
self._shutdown_event: threading.Event = threading.Event()
self._captured_exception_event: threading.Event = threading.Event()
self._captured_exception: Optional[Exception] = None

self.update_queue: queue.Queue[
Union[MetricUpdateJob, SynchronizationMarker]
@@ -132,8 +140,16 @@ def _update_rec_metrics(
if self._shutdown_event.is_set():
raise RecMetricException("metric processor thread is shut down.")

if self._captured_exception_event.is_set():
assert self._captured_exception is not None
raise self._captured_exception

try:
cpu_model_out, transfer_completed_event = self._transfer_to_cpu(model_out)
cpu_model_out, transfer_completed_event = (
self._transfer_to_cpu(model_out)
if self._device == torch.device("cuda")
else (model_out, None)
)
self.update_queue.put_nowait(
MetricUpdateJob(
model_out=cpu_model_out,
@@ -191,31 +207,25 @@ def _process_metric_update_job(self, metric_update_job: MetricUpdateJob) -> None
"""

with record_function("## CPUOffloadedRecMetricModule:update ##"):
try:
if metric_update_job.transfer_completed_event is not None:
metric_update_job.transfer_completed_event.synchronize()
labels, predictions, weights, required_inputs = (
parse_task_model_outputs(
self.rec_tasks,
metric_update_job.model_out,
self.get_required_inputs(),
)
)
if required_inputs:
metric_update_job.kwargs["required_inputs"] = required_inputs

self.rec_metrics.update(
predictions=predictions,
labels=labels,
weights=weights,
**metric_update_job.kwargs,
)

if self.throughput_metric:
self.throughput_metric.update()
labels, predictions, weights, required_inputs = parse_task_model_outputs(
self.rec_tasks,
metric_update_job.model_out,
self.get_required_inputs(),
)
if required_inputs:
metric_update_job.kwargs["required_inputs"] = required_inputs

self.rec_metrics.update(
predictions=predictions,
labels=labels,
weights=weights,
**metric_update_job.kwargs,
)

except Exception as e:
logger.exception("Error processing metric update: %s", e)
raise e
if self.throughput_metric:
self.throughput_metric.update()

@override
def shutdown(self) -> None:
@@ -248,30 +258,34 @@ def shutdown(self) -> None:
logger.info("CPUOffloadedRecMetricModule has been successfully shutdown.")

@override
def compute(self) -> Dict[str, MetricValue]:
def compute(self) -> MetricsResult:
raise RecMetricException(
"compute() is not supported in CPUOffloadedRecMetricModule. Use async_compute() instead."
"CPUOffloadedRecMetricModule does not support compute(). Use async_compute() instead."
)

@override
def async_compute(
self, future: concurrent.futures.Future[Dict[str, MetricValue]]
) -> None:
def async_compute(self) -> MetricsFuture:
"""
Entry point for asynchronous metric compute. It enqueues a synchronization marker
to the update queue.

Args:
Returns:
metrics_future: Future that will be fulfilled with the computed metrics once the compute job completes.
"""
metrics_future = concurrent.futures.Future()
if self._shutdown_event.is_set():
future.set_exception(
metrics_future.set_exception(
RecMetricException("metric processor thread is shut down.")
)
return
return metrics_future

if self._captured_exception_event.is_set():
assert self._captured_exception is not None
raise self._captured_exception

self.update_queue.put_nowait(SynchronizationMarker(future))
self.update_queue.put_nowait(SynchronizationMarker(metrics_future))
self.update_queue_size_logger.add(self.update_queue.qsize())
return metrics_future

def _process_synchronization_marker(
self, synchronization_marker: SynchronizationMarker
@@ -304,7 +318,7 @@ def _process_synchronization_marker(

def _process_metric_compute_job(
self, metric_compute_job: MetricComputeJob
) -> Dict[str, MetricValue]:
) -> MetricsResult:
"""
Process a metric compute job:
1. Comms module performs all gather
@@ -355,6 +369,8 @@ def _update_loop(self) -> None:
self._do_work(self.update_queue)
except Exception as e:
logger.exception(f"Exception in update loop: {e}")
self._captured_exception_event.set()
self._captured_exception = e
raise e

remaining = self._flush_remaining_work(self.update_queue)
@@ -372,6 +388,8 @@ def _compute_loop(self) -> None:
self._do_work(self.compute_queue)
except Exception as e:
logger.exception(f"Exception in compute loop: {e}")
self._captured_exception_event.set()
self._captured_exception = e
raise e

remaining = self._flush_remaining_work(self.compute_queue)
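To make the control flow in this file easier to follow, here is a simplified, self-contained model of the producer/consumer pattern (illustrative only, not CPUOffloadedRecMetricModule itself; `OffloadedMetricPipeline` and its counter "metric" are hypothetical stand-ins, while the job-type names, the update-queue marker, and the captured-exception handoff mirror the diff):

```python
# Illustrative stand-in for the offloaded update/compute flow; not TorchRec code.
import concurrent.futures
import queue
import threading
from dataclasses import dataclass, field
from typing import Any, Dict, Optional, Union


@dataclass
class MetricUpdateJob:
    model_out: Dict[str, Any]
    kwargs: Dict[str, Any] = field(default_factory=dict)


@dataclass
class SynchronizationMarker:
    future: "concurrent.futures.Future[Dict[str, float]]"


class OffloadedMetricPipeline:
    """Trainer thread enqueues work; a background thread applies it."""

    def __init__(self, update_queue_size: int = 100) -> None:
        self._count = 0  # stand-in for real metric state
        self._captured_exception: Optional[Exception] = None
        self._captured_exception_event = threading.Event()
        self.update_queue: "queue.Queue[Union[MetricUpdateJob, SynchronizationMarker]]" = (
            queue.Queue(maxsize=update_queue_size)
        )
        threading.Thread(target=self._update_loop, daemon=True).start()

    def update(self, model_out: Dict[str, Any]) -> None:
        # Non-blocking on the trainer thread; re-raise any worker-side failure.
        if self._captured_exception_event.is_set():
            assert self._captured_exception is not None
            raise self._captured_exception
        self.update_queue.put_nowait(MetricUpdateJob(model_out=model_out))

    def async_compute(self) -> "concurrent.futures.Future[Dict[str, float]]":
        # Enqueue a marker; the worker fulfills the future once every update
        # enqueued before the marker has been applied.
        metrics_future: "concurrent.futures.Future[Dict[str, float]]" = (
            concurrent.futures.Future()
        )
        self.update_queue.put_nowait(SynchronizationMarker(metrics_future))
        return metrics_future

    def _update_loop(self) -> None:
        try:
            while True:
                job = self.update_queue.get()
                if isinstance(job, MetricUpdateJob):
                    self._count += len(job.model_out)  # pretend metric update
                else:
                    job.future.set_result({"examples_seen": float(self._count)})
        except Exception as e:
            # Remember the failure so the trainer thread can surface it later.
            self._captured_exception = e
            self._captured_exception_event.set()
            raise


# Usage: update() never blocks on metric math; the future resolves once the
# worker has drained everything enqueued before async_compute() was called.
pipeline = OffloadedMetricPipeline()
pipeline.update({"prediction": [0.3], "label": [1.0]})
print(pipeline.async_compute().result(timeout=5))
```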
8 changes: 5 additions & 3 deletions torchrec/metrics/metric_job_types.py
@@ -8,7 +8,7 @@
# pyre-strict

import concurrent
from typing import Any, Dict
from typing import Any, Dict, Optional

import torch
from torchrec.metrics.metric_module import MetricValue
@@ -26,7 +26,7 @@ class MetricUpdateJob:
def __init__(
self,
model_out: Dict[str, torch.Tensor],
transfer_completed_event: torch.cuda.Event,
transfer_completed_event: Optional[torch.cuda.Event],
kwargs: Dict[str, Any],
) -> None:
"""
@@ -37,7 +9 @@ def __init__(
"""

self.model_out: Dict[str, torch.Tensor] = model_out
self.transfer_completed_event: torch.cuda.Event = transfer_completed_event
self.transfer_completed_event: Optional[torch.cuda.Event] = (
transfer_completed_event
)
self.kwargs: Dict[str, Any] = kwargs


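The now-optional event pairs with the conditional GPU-to-CPU transfer earlier in the stack: on a CPU-only module there is no device copy, hence no CUDA event to wait on. A hedged sketch of how a producer and consumer might handle it (assumption: `copy_model_out_to_cpu` and `wait_for_transfer` are hypothetical helpers, not TorchRec APIs):

```python
# Hedged sketch: when the module already runs on CPU there is no device copy,
# so there is no CUDA event to record and the field is simply None.
from typing import Dict, Optional, Tuple

import torch


def copy_model_out_to_cpu(
    model_out: Dict[str, torch.Tensor], device: torch.device
) -> Tuple[Dict[str, torch.Tensor], Optional[torch.cuda.Event]]:
    if device.type != "cuda":
        return model_out, None  # already on CPU: nothing to transfer
    event = torch.cuda.Event()
    cpu_out = {k: v.to("cpu", non_blocking=True) for k, v in model_out.items()}
    event.record()  # recorded on the current stream after the async copies
    return cpu_out, event


def wait_for_transfer(event: Optional[torch.cuda.Event]) -> None:
    # Mirrors the None-guard performed before parsing the model outputs.
    if event is not None:
        event.synchronize()
```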