Illustrate the usage of infer_discrete in the annotation example#1043
Illustrate the usage of infer_discrete in the annotation example#1043eb8680 merged 11 commits intopyro-ppl:masterfrom
Conversation
eb8680
left a comment
There was a problem hiding this comment.
Looks good to me.
Should we make some helpers like MCMC.infer_discrete or Predictive.infer_discrete that does the job in this annotation example under the hood?
Yes - in fact, this should probably be the default behavior of Predictive for HMC/NUTS when infer_discrete is finished.
| plates=terms["plate_vars"], | ||
| ) | ||
| log_prob = funsor.optimizer.apply_optimizer(log_prob) | ||
| with funsor.adjoint.AdjointTape() as tape: |
There was a problem hiding this comment.
Does using lazy here as in the old version and funsor.adjoint.adjoint below not work? I see we made this change in #991 but can't recall why.
|
|
||
| with approx: | ||
| approx_factors = funsor.adjoint.adjoint(sum_op, prod_op, log_prob) | ||
| approx_factors = tape.adjoint(sum_op, prod_op, log_prob) |
There was a problem hiding this comment.
Does using the more functional interface funsor.adjoint.adjoint not work here?
There was a problem hiding this comment.
Sorry for the late response, @eb8680! Yes, for some reason there is an error happens at this line.
File "/home/fehiepsi/numpyro/numpyro/contrib/funsor/discrete.py", line 120, in _sample_posterior
approx_factors = funsor.adjoint.adjoint(sum_op, prod_op, log_prob)
File "/home/fehiepsi/funsor/funsor/adjoint.py", line 135, in adjoint
return tape.adjoint(sum_op, bin_op, root)
File "/home/fehiepsi/funsor/funsor/adjoint.py", line 115, in adjoint
in_adjs = adjoint_ops(fn, sum_op, bin_op, adjoint_values[output], *inputs)
File "/home/fehiepsi/funsor/funsor/registry.py", line 106, in __call__
return self[key](*args)
File "/home/fehiepsi/funsor/funsor/registry.py", line 63, in __call__
return self.partial_call(*args)(*args)
File "/home/fehiepsi/funsor/funsor/adjoint.py", line 209, in adjoint_contract_generic
assert len(terms) == 1 or len(terms) == 2
AssertionError
|
Thanks for helping me review this PR, @eb8680!! |
It seems that we have some
infer_discretebug in the master branch. I merged changes in #991 and it works for models in this example.Changes in this PR:
replacehandler'sguide_traceargument (to be consistent with Pyro)Questions for reviewers: I found that it is tricky for users to perform
infer_discretecorrectly to get extra posterior samples for discrete latent sites. Should we make some helpers likeMCMC.infer_discreteorPredictive.infer_discretethat does the job in thisannotationexample under the hood?I imagine that it would be convenient if we have something like