Skip to content

Reparametrized normal with base noise exposed#601

Open
rfl-urbaniak wants to merge 5 commits intomasterfrom
ru-reparametrized-normal
Open

Reparametrized normal with base noise exposed#601
rfl-urbaniak wants to merge 5 commits intomasterfrom
ru-reparametrized-normal

Conversation

@rfl-urbaniak
Copy link
Collaborator

No description provided.

@rfl-urbaniak rfl-urbaniak added the status:WIP Work-in-progress not yet ready for review label Oct 8, 2025
Comment on lines +10 to +13
def __init__(
self,
):
super().__init__()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove me.

Comment on lines +21 to +22
if isinstance(fn, dist.Independent):
base = fn.base_dist
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are we assuming that all Independent wrapped dists are Normal?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Those to which you specifically point as the ones you want to reparametrize as reparametrized normal, I think so. Any reason to think otherwise?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's worth a check since reparam can be used automatically as well.

):
super().__init__()

def apply(self, msg):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

def apply(self, msg: pyro.infer.reparam.reparam.ReparamMessage) -> pyro.infer.reparam.reparam.ReparamResult

if value_indices != IndexSet():
raise NotImplementedError("Partially observed Normal reparameterization is not implemented.")

if not is_observed:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

else:

Copy link
Contributor

@dimkab dimkab left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

reparam is not supposed to be used on observed sites. better to raise an exception.

scale = base.scale
event_dim = fn.event_dim

if is_observed:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suggest to raise an exception here. Reparametrization is not supposed to be used on observed sites, and in fact none of the tests do so. See e.g.
https://docs.pyro.ai/en/1.5.0/_modules/pyro/infer/reparam/loc_scale.html

assert torch.allclose(y2["log_prob"], dist.Normal(loc, scale).log_prob(y2["value"]))

# reparam will encouter Delta not a Normal or Independent(Normal)
with pytest.raises(ValueError, match="NormalReparam only supports Normal or Independent\\(Normal\\)"):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

again, reparam is not supposed to be used on observed sites, and therefore this part of the test is not needed

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

status:WIP Work-in-progress not yet ready for review

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants