@@ -148,7 +148,7 @@ def annotate(self, mmif: Union[str, dict, Mmif], **runtime_params: List[str]) ->
         pretty = refined.get('pretty', False)
         t = datetime.now()
         with warnings.catch_warnings(record=True) as ws:
-            annotated = self._annotate(mmif, **refined)
+            annotated, cuda_profiler = self._profile_cuda_memory(self._annotate)(mmif, **refined)
         if ws:
             issued_warnings.extend(ws)
         if issued_warnings:
@@ -164,11 +164,16 @@ def annotate(self, mmif: Union[str, dict, Mmif], **runtime_params: List[str]) ->
         runtime_recs['architecture'] = platform.machine()
         # runtime_recs['processor'] = platform.processor()  # this only works on Windows
         runtime_recs['cuda'] = []
-        if shutil.which('nvidia-smi'):
+        # Use cuda_profiler data if available, otherwise fall back to nvidia-smi
+        if cuda_profiler:
+            for gpu_info, peak_memory_bytes in cuda_profiler.items():
+                # Convert peak memory to human-readable format
+                runtime_recs['cuda'].append(f"{gpu_info}, Used {self._cuda_memory_to_str(peak_memory_bytes)}")
+        elif shutil.which('nvidia-smi'):
             for gpu in subprocess.run(['nvidia-smi', '--query-gpu=name,memory.total', '--format=csv,noheader'],
                                       stdout=subprocess.PIPE).stdout.decode('utf-8').strip().split('\n'):
                 name, mem = gpu.split(', ')
-                runtime_recs['cuda'].append(f'{name} ({mem})')
+                runtime_recs['cuda'].append(self._cuda_device_name_concat(name, mem))
         for annotated_view in annotated.views:
             if annotated_view.metadata.app == self.metadata.identifier:
                 if runningTime:
@@ -321,6 +326,74 @@ def validate_document_locations(mmif: Union[str, Mmif]) -> None:
                 # (https://github.com/clamsproject/mmif/issues/150), here is a good place for additional check for
                 # file integrity
 
+    @staticmethod
+    def _cuda_memory_to_str(mem) -> str:
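+        # illustrative values (not from the source): 2147483648 bytes -> "2.00 GiB", 536870912 bytes -> "512.0 MiB"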
+        mib = mem / (1024 * 1024)
+        if mib >= 1024:
+            return f"{mib / 1024:.2f} GiB"
+        else:
+            return f"{mib:.1f} MiB"
+
+    @staticmethod
+    def _cuda_device_name_concat(name, mem):
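+        # `mem` may be a raw byte count (int, from torch) or an already formatted string (e.g. "24576 MiB" from nvidia-smi)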
+        if type(mem) in (bytes, int):
+            mem = ClamsApp._cuda_memory_to_str(mem)
+        return f"{name}, With {mem}"
+
+    @staticmethod
+    def _profile_cuda_memory(func):
+        """
+        Decorator for profiling CUDA memory usage during _annotate execution.
+
+        :param func: The function to wrap (typically _annotate)
+        :return: Decorated function that returns (result, cuda_profiler),
+            where cuda_profiler is a dict with "<GPU_NAME>, <GPU_TOTAL_MEMORY>" keys
+            and peak memory usage values
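+
+        Illustrative usage, mirroring the call in annotate() (actual device names and sizes vary by host):
+            annotated, cuda_profiler = self._profile_cuda_memory(self._annotate)(mmif, **refined)
+            # e.g. {"NVIDIA RTX A6000, With 48.00 GiB": 1234567890}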
357+ """
358+ def wrapper (* args , ** kwargs ):
359+ cuda_profiler = {}
360+ torch_available = False
361+ cuda_available = False
362+ device_count = 0
363+
+            try:
+                import torch  # pytype: disable=import-error
+                torch_available = True
+                cuda_available = torch.cuda.is_available()
+                device_count = torch.cuda.device_count()
+                if cuda_available:
+                    # Reset peak memory stats for all devices
+                    for device_id in range(device_count):
+                        torch.cuda.reset_peak_memory_stats(f'cuda:{device_id}')
+            except ImportError:
+                pass
+
+            try:
+                result = func(*args, **kwargs)
+
+                if torch_available and cuda_available and device_count > 0:
+                    for device_id in range(device_count):
+                        device_id = f'cuda:{device_id}'
+                        peak_memory = torch.cuda.max_memory_allocated(device_id)
+                        gpu_name = torch.cuda.get_device_name(device_id)
+                        gpu_total_memory = torch.cuda.get_device_properties(device_id).total_memory
+                        key = ClamsApp._cuda_device_name_concat(gpu_name, gpu_total_memory)
+                        cuda_profiler[key] = peak_memory
+
+                return result, cuda_profiler
+            finally:
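+                # release unoccupied cached memory held by the torch allocator back to the device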
+                if torch_available and cuda_available:
+                    torch.cuda.empty_cache()
+
+        return wrapper
+
     @staticmethod
     @contextmanager
     def open_document_location(document: Union[str, Document], opener: Any = open, **openerargs):