From 9da9155b440d82750ac55a61e8b50a04bb08e510 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Mon, 10 Nov 2025 18:05:06 +0000 Subject: [PATCH 1/3] onload from dataloader Signed-off-by: Kyle Sayers --- tests/llmcompressor/pipelines/test_cache.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/llmcompressor/pipelines/test_cache.py b/tests/llmcompressor/pipelines/test_cache.py index c5c10c3cb..ab5381b7a 100644 --- a/tests/llmcompressor/pipelines/test_cache.py +++ b/tests/llmcompressor/pipelines/test_cache.py @@ -156,8 +156,8 @@ def deep_equal(a, b) -> bool: return False return all(deep_equal(a[key], b[key]) for key in a.keys()) case _ if is_dataclass(a): - a_dict = {field.name: getattr(a, field.name) for field in fields(a)} - b_dict = {field.name: getattr(b, field.name) for field in fields(b)} + a_dict = {field: getattr(a, field.name) for field in fields(a)} + b_dict = {field: getattr(b, field.name) for field in fields(b)} return deep_equal(a_dict, b_dict) case _: From 739b86df5a4e83fc2805e1e08697132567e0ccac Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Mon, 10 Nov 2025 19:00:27 +0000 Subject: [PATCH 2/3] add support for frozen dataclasses Signed-off-by: Kyle Sayers --- src/llmcompressor/pipelines/cache.py | 29 +++++++++++---------- tests/llmcompressor/pipelines/test_cache.py | 11 +++++--- 2 files changed, 22 insertions(+), 18 deletions(-) diff --git a/src/llmcompressor/pipelines/cache.py b/src/llmcompressor/pipelines/cache.py index b647c6824..1beb377e4 100644 --- a/src/llmcompressor/pipelines/cache.py +++ b/src/llmcompressor/pipelines/cache.py @@ -1,7 +1,7 @@ import sys import warnings from collections import defaultdict -from dataclasses import dataclass, fields, is_dataclass +from dataclasses import dataclass, is_dataclass, fields from typing import Any, Dict, Generator, List, Optional, Union import torch @@ -160,8 +160,8 @@ def _size_helper(intermediate: IntermediateValue) -> int: for v in value.values(): _size_helper(v) case _ if is_dataclass(value): - for field in fields(value): - _size_helper(getattr(value, field.name)) + for f in fields(value): + _size_helper(getattr(value, f.name)) case _: # this handles primitive values that don't match any other cases sizes[torch.device("cpu")] += sys.getsizeof(value, 0) @@ -205,10 +205,10 @@ def _onload_value(cls, intermediate: IntermediateValue) -> Any: case dict(): return {k: cls._onload_value(v) for k, v in value.items()} case _ if is_dataclass(value): - for field in fields(value): - v = getattr(value, field.name) - setattr(value, field.name, cls._onload_value(v)) - return value + return type(value)(**{ + f.name: cls._onload_value(getattr(value, f.name)) + for f in fields(value) + }) case _: # handles primitive values that should be returned as is. # without this, a MatchError would be raised for unhandled types. @@ -249,16 +249,17 @@ def _offload_value( ) case dict(): return IntermediateValue( - value={ - k: cls._offload_value(v, **kwargs) for k, v in value.items() - }, + value={k: cls._offload_value(v, **kwargs) for k, v in value.items()}, device=None, ) case _ if is_dataclass(value): - for field in fields(value): - v = getattr(value, field.name) - setattr(value, field.name, cls._offload_value(v, **kwargs)) - return IntermediateValue(value=value, device=None) + return IntermediateValue( + value=type(value)(**{ + f.name: cls._offload_value(getattr(value, f.name), **kwargs) + for f in fields(value) + }), + device=None + ) case _: # handles primitive values and provides a warning for unsupported types. # without this, values trigger a MatchError exception. diff --git a/tests/llmcompressor/pipelines/test_cache.py b/tests/llmcompressor/pipelines/test_cache.py index ab5381b7a..2cca47f1c 100644 --- a/tests/llmcompressor/pipelines/test_cache.py +++ b/tests/llmcompressor/pipelines/test_cache.py @@ -1,3 +1,5 @@ +from typing import Optional + from dataclasses import dataclass, fields, is_dataclass import pytest @@ -7,10 +9,11 @@ from llmcompressor.pipelines.cache import IntermediatesCache -@dataclass +@dataclass(frozen=True) class SampleDataclass: - a: torch.Tensor - b: int + a: int + b: Optional[torch.Tensor] = None + c: Optional["SampleDataclass"] = None @pytest.fixture @@ -35,7 +38,7 @@ def sample_cache(sample_dataloader): values_to_test = [ torch.randn(2, 3).to("cpu"), - SampleDataclass(a=torch.randn(2, 3), b=42), + SampleDataclass(a=42, b=torch.randn(2, 3), c=SampleDataclass(a=64)), torch.float32, [1, 2, 3], ] From a488f21b5f3be9f10a0f6b2fb3c688c9a98e5170 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Thu, 18 Dec 2025 06:28:06 +0000 Subject: [PATCH 3/3] apply style Signed-off-by: Kyle Sayers --- src/llmcompressor/pipelines/cache.py | 28 +++++++++++++-------- tests/llmcompressor/pipelines/test_cache.py | 5 ++-- 2 files changed, 19 insertions(+), 14 deletions(-) diff --git a/src/llmcompressor/pipelines/cache.py b/src/llmcompressor/pipelines/cache.py index 1beb377e4..4980d7f30 100644 --- a/src/llmcompressor/pipelines/cache.py +++ b/src/llmcompressor/pipelines/cache.py @@ -1,7 +1,7 @@ import sys import warnings from collections import defaultdict -from dataclasses import dataclass, is_dataclass, fields +from dataclasses import dataclass, fields, is_dataclass from typing import Any, Dict, Generator, List, Optional, Union import torch @@ -205,10 +205,12 @@ def _onload_value(cls, intermediate: IntermediateValue) -> Any: case dict(): return {k: cls._onload_value(v) for k, v in value.items()} case _ if is_dataclass(value): - return type(value)(**{ - f.name: cls._onload_value(getattr(value, f.name)) - for f in fields(value) - }) + return type(value)( + **{ + f.name: cls._onload_value(getattr(value, f.name)) + for f in fields(value) + } + ) case _: # handles primitive values that should be returned as is. # without this, a MatchError would be raised for unhandled types. @@ -249,16 +251,20 @@ def _offload_value( ) case dict(): return IntermediateValue( - value={k: cls._offload_value(v, **kwargs) for k, v in value.items()}, + value={ + k: cls._offload_value(v, **kwargs) for k, v in value.items() + }, device=None, ) case _ if is_dataclass(value): return IntermediateValue( - value=type(value)(**{ - f.name: cls._offload_value(getattr(value, f.name), **kwargs) - for f in fields(value) - }), - device=None + value=type(value)( + **{ + f.name: cls._offload_value(getattr(value, f.name), **kwargs) + for f in fields(value) + } + ), + device=None, ) case _: # handles primitive values and provides a warning for unsupported types. diff --git a/tests/llmcompressor/pipelines/test_cache.py b/tests/llmcompressor/pipelines/test_cache.py index 2cca47f1c..645ccee61 100644 --- a/tests/llmcompressor/pipelines/test_cache.py +++ b/tests/llmcompressor/pipelines/test_cache.py @@ -1,6 +1,5 @@ -from typing import Optional - from dataclasses import dataclass, fields, is_dataclass +from typing import Optional import pytest import torch @@ -12,7 +11,7 @@ @dataclass(frozen=True) class SampleDataclass: a: int - b: Optional[torch.Tensor] = None + b: Optional[torch.Tensor] = None c: Optional["SampleDataclass"] = None