
Commit d3aba7d

[Feature] Support callable scale in IndependentNormal and TanhNormal distributions
Allow the scale parameter to be a callable (e.g., torch.ones_like) for torch.compile friendliness. This enables more flexible distribution construction without needing to pre-compute scale tensors.

ghstack-source-id: e1136a9
Pull-Request: #3296
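A minimal usage sketch of the new behavior (based on the docstring examples and tests added in the diff below; not part of the commit itself):

```python
# Sketch of the callable-scale API added in this commit.
from functools import partial

import torch

from torchrl.modules import IndependentNormal, TanhNormal

loc = torch.zeros(3, 4)

# Passing a callable lets the distribution build the scale from ``loc``,
# so no scale tensor has to be materialized on the right device beforehand.
dist = TanhNormal(loc, scale=torch.ones_like, low=-1, high=1)
sample = dist.rsample()  # sample in (-1, 1), same shape and device as loc

# For a non-unit scale, wrap torch.full_like with functools.partial.
dist = IndependentNormal(loc, scale=partial(torch.full_like, fill_value=0.1))
log_prob = dist.log_prob(dist.sample())
```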
1 parent 5c75777 commit d3aba7d

4 files changed: 253 additions & 10 deletions


test/test_distributions.py

Lines changed: 180 additions & 0 deletions
@@ -6,6 +6,7 @@
 
 import argparse
 import importlib.util
+from functools import partial
 
 import pytest
 import torch
@@ -16,6 +17,7 @@
 from torch import autograd, nn
 from torch.utils._pytree import tree_map
 from torchrl.modules import (
+    IndependentNormal,
     OneHotCategorical,
     OneHotOrdinal,
     Ordinal,
@@ -169,6 +171,184 @@ def test_tanhnormal_event_dims(self, event_dims):
             exp_shape,
         )
 
+    @pytest.mark.parametrize("device", get_default_devices())
+    @pytest.mark.parametrize(
+        "callable_scale",
+        [torch.ones_like, partial(torch.full_like, fill_value=0.5)],
+        ids=["ones_like", "full_like_partial"],
+    )
+    def test_tanhnormal_callable_scale(self, device, callable_scale):
+        """Test that TanhNormal supports callable scale for compile-friendliness.
+
+        Using a callable scale (e.g., torch.ones_like or partial(torch.full_like, fill_value=...))
+        avoids explicit device transfers and prevents graph breaks in torch.compile.
+        """
+        torch.manual_seed(0)
+        loc = torch.randn(3, 4, device=device)
+
+        # Create distribution with callable scale
+        dist = TanhNormal(loc=loc, scale=callable_scale, low=-1, high=1)
+
+        # Check that the scale was properly resolved
+        expected_scale = callable_scale(loc)
+        torch.testing.assert_close(dist.scale, expected_scale)
+
+        # Test sampling
+        sample = dist.sample()
+        assert sample.shape == loc.shape
+        assert sample.device == loc.device
+        assert (sample >= -1).all()
+        assert (sample <= 1).all()
+
+        # Test log_prob
+        log_prob = dist.log_prob(sample)
+        assert torch.isfinite(log_prob).all()
+
+        # Test rsample with gradient
+        loc_grad = torch.randn(3, 4, device=device, requires_grad=True)
+        dist_grad = TanhNormal(loc=loc_grad, scale=callable_scale, low=-1, high=1)
+        sample_grad = dist_grad.rsample()
+        loss = sample_grad.sum()
+        loss.backward()
+        assert loc_grad.grad is not None
+        assert torch.isfinite(loc_grad.grad).all()
+
+    @pytest.mark.parametrize("device", get_default_devices())
+    def test_tanhnormal_callable_scale_update(self, device):
+        """Test that TanhNormal.update() works with callable scale."""
+        torch.manual_seed(0)
+        loc = torch.randn(3, 4, device=device)
+        callable_scale = torch.ones_like
+
+        dist = TanhNormal(loc=loc, scale=callable_scale, low=-1, high=1)
+
+        # Update with new loc and callable scale
+        new_loc = torch.randn(3, 4, device=device)
+        dist.update(new_loc, callable_scale)
+
+        # Check that scale was properly resolved
+        torch.testing.assert_close(dist.scale, torch.ones_like(new_loc))
+
+        # Verify distribution works after update
+        sample = dist.sample()
+        assert sample.shape == new_loc.shape
+        assert torch.isfinite(dist.log_prob(sample)).all()
+
+
+class TestIndependentNormal:
+    @pytest.mark.parametrize("device", get_default_devices())
+    @pytest.mark.parametrize(
+        "callable_scale",
+        [torch.ones_like, partial(torch.full_like, fill_value=0.5)],
+        ids=["ones_like", "full_like_partial"],
+    )
+    def test_independentnormal_callable_scale(self, device, callable_scale):
+        """Test that IndependentNormal supports callable scale for compile-friendliness.
+
+        Using a callable scale (e.g., torch.ones_like or partial(torch.full_like, fill_value=...))
+        avoids explicit device transfers and prevents graph breaks in torch.compile.
+        """
+        torch.manual_seed(0)
+        loc = torch.randn(3, 4, device=device)
+
+        # Create distribution with callable scale
+        dist = IndependentNormal(loc=loc, scale=callable_scale)
+
+        # Check that the scale was properly resolved
+        expected_scale = callable_scale(loc)
+        torch.testing.assert_close(dist.base_dist.scale, expected_scale)
+
+        # Test sampling
+        sample = dist.sample()
+        assert sample.shape == loc.shape
+        assert sample.device == loc.device
+
+        # Test log_prob
+        log_prob = dist.log_prob(sample)
+        assert torch.isfinite(log_prob).all()
+
+        # Test rsample with gradient
+        loc_grad = torch.randn(3, 4, device=device, requires_grad=True)
+        dist_grad = IndependentNormal(loc=loc_grad, scale=callable_scale)
+        sample_grad = dist_grad.rsample()
+        loss = sample_grad.sum()
+        loss.backward()
+        assert loc_grad.grad is not None
+        assert torch.isfinite(loc_grad.grad).all()
+
+    @pytest.mark.parametrize("device", get_default_devices())
+    def test_independentnormal_callable_scale_update(self, device):
+        """Test that IndependentNormal.update() works with callable scale."""
+        torch.manual_seed(0)
+        loc = torch.randn(3, 4, device=device)
+        callable_scale = torch.ones_like
+
+        dist = IndependentNormal(loc=loc, scale=callable_scale)
+
+        # Update with new loc and callable scale
+        new_loc = torch.randn(3, 4, device=device)
+        dist.update(new_loc, callable_scale)
+
+        # Check that scale was properly resolved
+        torch.testing.assert_close(dist.base_dist.scale, torch.ones_like(new_loc))
+
+        # Verify distribution works after update
+        sample = dist.sample()
+        assert sample.shape == new_loc.shape
+        assert torch.isfinite(dist.log_prob(sample)).all()
+
+    @pytest.mark.parametrize("device", get_default_devices())
+    @pytest.mark.parametrize("scale_type", ["tensor", "float", "callable"])
+    def test_independentnormal_scale_types(self, device, scale_type):
+        """Test that IndependentNormal supports all scale types: tensor, float, callable."""
+        torch.manual_seed(0)
+        loc = torch.randn(3, 4, device=device)
+
+        if scale_type == "tensor":
+            scale = torch.ones(3, 4, device=device)
+        elif scale_type == "float":
+            scale = 1.0
+        else:  # callable
+            scale = torch.ones_like
+
+        dist = IndependentNormal(loc=loc, scale=scale)
+
+        # Test sampling
+        sample = dist.sample()
+        assert sample.shape == loc.shape
+        assert sample.device == loc.device
+
+        # Test log_prob
+        log_prob = dist.log_prob(sample)
+        assert torch.isfinite(log_prob).all()
+
+    @pytest.mark.parametrize("device", get_default_devices())
+    @pytest.mark.parametrize("scale_type", ["tensor", "float", "callable"])
+    def test_tanhnormal_scale_types(self, device, scale_type):
+        """Test that TanhNormal supports all scale types: tensor, float, callable."""
+        torch.manual_seed(0)
+        loc = torch.randn(3, 4, device=device)
+
+        if scale_type == "tensor":
+            scale = torch.ones(3, 4, device=device)
+        elif scale_type == "float":
+            scale = 1.0
+        else:  # callable
+            scale = torch.ones_like
+
+        dist = TanhNormal(loc=loc, scale=scale, low=-1, high=1)
+
+        # Test sampling
+        sample = dist.sample()
+        assert sample.shape == loc.shape
+        assert sample.device == loc.device
+        assert (sample >= -1).all()
+        assert (sample <= 1).all()
+
+        # Test log_prob
+        log_prob = dist.log_prob(sample)
+        assert torch.isfinite(log_prob).all()
+
 
 class TestTruncatedNormal:
     @pytest.mark.parametrize(

torchrl/__init__.py

Lines changed: 0 additions & 1 deletion
@@ -2,7 +2,6 @@
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
-import os
 import warnings
 import weakref
 from warnings import warn

torchrl/modules/distributions/continuous.py

Lines changed: 70 additions & 8 deletions
@@ -5,7 +5,7 @@
 from __future__ import annotations
 
 import weakref
-from collections.abc import Sequence
+from collections.abc import Callable, Sequence
 from numbers import Number
 
 import numpy as np
@@ -58,7 +58,11 @@ class IndependentNormal(D.Independent):
 
     Args:
         loc (torch.Tensor): normal distribution location parameter
-        scale (torch.Tensor): normal distribution sigma parameter (squared root of variance)
+        scale (torch.Tensor, float, or callable): normal distribution sigma parameter (squared root of variance).
+            Can be a tensor, a float, or a callable that takes the ``loc`` tensor as input and returns the scale tensor.
+            Using a callable (e.g., ``torch.ones_like`` or ``functools.partial(torch.full_like, fill_value=0.1)``)
+            avoids explicit device transfers like ``torch.tensor(val, device=device)`` and prevents graph breaks
+            in :func:`torch.compile`.
         upscale (torch.Tensor or number, optional): 'a' scaling factor in the formula:
 
         .. math::
@@ -69,14 +73,28 @@ class IndependentNormal(D.Independent):
         tanh_loc (bool, optional): if ``False``, the above formula is used for
             the location scaling, otherwise the raw value
            is kept. Default is ``False``;
+
+    Example:
+        >>> import torch
+        >>> from functools import partial
+        >>> from torchrl.modules.distributions import IndependentNormal
+        >>> loc = torch.zeros(3, 4)
+        >>> # Using a callable scale avoids device transfers and graph breaks in torch.compile
+        >>> dist = IndependentNormal(loc, scale=torch.ones_like)
+        >>> # For a custom scale value, use partial to create a callable
+        >>> dist = IndependentNormal(loc, scale=partial(torch.full_like, fill_value=0.1))
+        >>> sample = dist.sample()
+        >>> sample.shape
+        torch.Size([3, 4])
+
     """
 
     num_params: int = 2
 
     def __init__(
         self,
         loc: torch.Tensor,
-        scale: torch.Tensor,
+        scale: torch.Tensor | float | Callable[[torch.Tensor], torch.Tensor],
         upscale: float = 5.0,
         tanh_loc: bool = False,
         event_dim: int = 1,
@@ -86,11 +104,25 @@ def __init__(
         self.upscale = upscale
         self._event_dim = event_dim
         self._kwargs = kwargs
+        # Support callable scale (e.g., torch.ones_like) for compile-friendliness
+        if callable(scale) and not isinstance(scale, torch.Tensor):
+            scale = scale(loc)
+        elif not isinstance(scale, torch.Tensor):
+            scale = torch.as_tensor(scale, device=loc.device, dtype=loc.dtype)
+        elif scale.device != loc.device:
+            scale = scale.to(loc.device, non_blocking=loc.device.type == "cuda")
         super().__init__(D.Normal(loc, scale, **kwargs), event_dim)
 
     def update(self, loc, scale):
         if self.tanh_loc:
             loc = self.upscale * (loc / self.upscale).tanh()
+        # Support callable scale (e.g., torch.ones_like) for compile-friendliness
+        if callable(scale) and not isinstance(scale, torch.Tensor):
+            scale = scale(loc)
+        elif not isinstance(scale, torch.Tensor):
+            scale = torch.as_tensor(scale, device=loc.device, dtype=loc.dtype)
+        elif scale.device != loc.device:
+            scale = scale.to(loc.device, non_blocking=loc.device.type == "cuda")
         super().__init__(D.Normal(loc, scale, **self._kwargs), self._event_dim)
 
     @property
@@ -316,7 +348,11 @@ class TanhNormal(FasterTransformedDistribution):
 
     Args:
         loc (torch.Tensor): normal distribution location parameter
-        scale (torch.Tensor): normal distribution sigma parameter (squared root of variance)
+        scale (torch.Tensor, float, or callable): normal distribution sigma parameter (squared root of variance).
+            Can be a tensor, a float, or a callable that takes the ``loc`` tensor as input and returns the scale tensor.
+            Using a callable (e.g., ``torch.ones_like`` or ``functools.partial(torch.full_like, fill_value=0.1)``)
+            avoids explicit device transfers like ``torch.tensor(val, device=device)`` and prevents graph breaks
+            in :func:`torch.compile`.
     upscale (torch.Tensor or number): 'a' scaling factor in the formula:
 
         .. math::
@@ -331,6 +367,20 @@
             value is kept. Default is ``False``;
         safe_tanh (bool, optional): if ``True``, the Tanh transform is done "safely", to avoid numerical overflows.
             This will currently break with :func:`torch.compile`.
+
+    Example:
+        >>> import torch
+        >>> from functools import partial
+        >>> from torchrl.modules.distributions import TanhNormal
+        >>> loc = torch.zeros(3, 4)
+        >>> # Using a callable scale avoids device transfers and graph breaks in torch.compile
+        >>> dist = TanhNormal(loc, scale=torch.ones_like)
+        >>> # For a custom scale value, use partial to create a callable
+        >>> dist = TanhNormal(loc, scale=partial(torch.full_like, fill_value=0.1))
+        >>> sample = dist.sample()
+        >>> sample.shape
+        torch.Size([3, 4])
+
     """
 
     arg_constraints = {
@@ -343,7 +393,7 @@ class TanhNormal(FasterTransformedDistribution):
     def __init__(
         self,
         loc: torch.Tensor,
-        scale: torch.Tensor,
+        scale: torch.Tensor | float | Callable[[torch.Tensor], torch.Tensor],
         upscale: torch.Tensor | Number = 5.0,
         low: torch.Tensor | Number = -1.0,
         high: torch.Tensor | Number = 1.0,
@@ -353,8 +403,14 @@
     ):
         if not isinstance(loc, torch.Tensor):
             loc = torch.as_tensor(loc, dtype=torch.get_default_dtype())
-        if not isinstance(scale, torch.Tensor):
-            scale = torch.as_tensor(scale, dtype=torch.get_default_dtype())
+        _non_blocking = loc.device.type == "cuda"
+        # Support callable scale (e.g., torch.ones_like) for compile-friendliness
+        if callable(scale) and not isinstance(scale, torch.Tensor):
+            scale = scale(loc)
+        elif not isinstance(scale, torch.Tensor):
+            scale = torch.as_tensor(scale, device=loc.device, dtype=loc.dtype)
+        elif scale.device != loc.device:
+            scale = scale.to(loc.device, non_blocking=_non_blocking)
         if event_dims is None:
             event_dims = min(1, loc.ndim)
 
@@ -370,7 +426,6 @@
         if not all(high > low):
             raise RuntimeError(err_msg)
 
-        _non_blocking = loc.device.type == "cuda"
         if not isinstance(high, torch.Tensor):
             high = torch.as_tensor(high, device=loc.device)
         elif high.device != loc.device:
@@ -435,6 +490,13 @@ def update(self, loc: torch.Tensor, scale: torch.Tensor) -> None:
         # loc must be rescaled if tanh_loc
         if is_compiling() or (self.non_trivial_max or self.non_trivial_min):
             loc = loc + (self.high - self.low) / 2 + self.low
+        # Support callable scale (e.g., torch.ones_like) for compile-friendliness
+        if callable(scale) and not isinstance(scale, torch.Tensor):
+            scale = scale(loc)
+        elif not isinstance(scale, torch.Tensor):
+            scale = torch.as_tensor(scale, device=loc.device, dtype=loc.dtype)
+        elif scale.device != loc.device:
+            scale = scale.to(loc.device, non_blocking=loc.device.type == "cuda")
         self.loc = loc
         self.scale = scale
 
torchrl/modules/distributions/utils.py

Lines changed: 3 additions & 1 deletion
@@ -32,7 +32,9 @@ def _cast_transform_device(transform, device):
         for attribute in dir(transform):
             value = getattr(transform, attribute)
             if isinstance(value, torch.Tensor):
-                setattr(transform, attribute, value.to(device, non_blocking=_non_blocking))
+                setattr(
+                    transform, attribute, value.to(device, non_blocking=_non_blocking)
+                )
         return transform
     else:
         raise TypeError(
