Reparametrized normal with base noise exposed#601
Reparametrized normal with base noise exposed#601rfl-urbaniak wants to merge 5 commits intomasterfrom
Conversation
| def __init__( | ||
| self, | ||
| ): | ||
| super().__init__() |
| if isinstance(fn, dist.Independent): | ||
| base = fn.base_dist |
There was a problem hiding this comment.
Are we assuming that all Independent wrapped dists are Normal?
There was a problem hiding this comment.
Those to which you specifically point as the ones you want to reparametrize as reparametrized normal, I think so. Any reason to think otherwise?
There was a problem hiding this comment.
I think it's worth a check since reparam can be used automatically as well.
| ): | ||
| super().__init__() | ||
|
|
||
| def apply(self, msg): |
There was a problem hiding this comment.
def apply(self, msg: pyro.infer.reparam.reparam.ReparamMessage) -> pyro.infer.reparam.reparam.ReparamResult
| if value_indices != IndexSet(): | ||
| raise NotImplementedError("Partially observed Normal reparameterization is not implemented.") | ||
|
|
||
| if not is_observed: |
dimkab
left a comment
There was a problem hiding this comment.
reparam is not supposed to be used on observed sites. better to raise an exception.
| scale = base.scale | ||
| event_dim = fn.event_dim | ||
|
|
||
| if is_observed: |
There was a problem hiding this comment.
I suggest to raise an exception here. Reparametrization is not supposed to be used on observed sites, and in fact none of the tests do so. See e.g.
https://docs.pyro.ai/en/1.5.0/_modules/pyro/infer/reparam/loc_scale.html
| assert torch.allclose(y2["log_prob"], dist.Normal(loc, scale).log_prob(y2["value"])) | ||
|
|
||
| # reparam will encouter Delta not a Normal or Independent(Normal) | ||
| with pytest.raises(ValueError, match="NormalReparam only supports Normal or Independent\\(Normal\\)"): |
There was a problem hiding this comment.
again, reparam is not supposed to be used on observed sites, and therefore this part of the test is not needed
No description provided.