Add derivative support for std::beta and introduce digamma function utility #1733

vgvassilev merged 2 commits into vgvassilev:master
Conversation
clang-tidy review says "All clean, LGTM! 👍"
force-pushed from 8c22bde to b2b5388
Codecov Report: ✅ All modified and coverable lines are covered by tests.
```cpp
// CHECK-NEXT: return _t0.pushforward;
// CHECK-NEXT: }

double f_beta(double x, double y) { return std::beta(x, y); }
```
We need to also properly update test/Features/stl-cmath.cpp
force-pushed from b2b5388 to 3634001
force-pushed from 3634001 to 9566db9
```cpp
namespace std {
template <typename T>
inline T beta(T x, T y) {
  return std::tgamma(x) * std::tgamma(y) / std::tgamma(x + y);
```
We need to update the description on the top of this file accordingly.
I have updated the file description!
force-pushed from 9566db9 to 191e0c5
@guitargeek, can you take a look at this PR?
Hi, thanks for the PR! How did you come up with this particular implementation of the digamma function? Is it reusing code from the GNU Scientific Library? Do you have a reference for this "standard asymptotic expansion" that could be linked inline in the code?
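For context, the "standard asymptotic expansion" under discussion is presumably the usual Bernoulli-number series for the digamma function (see e.g. DLMF §5.11):

```latex
\psi(x) \sim \ln x - \frac{1}{2x} - \sum_{n=1}^{\infty} \frac{B_{2n}}{2n\,x^{2n}}
       = \ln x - \frac{1}{2x} - \frac{1}{12x^{2}} + \frac{1}{120x^{4}} - \frac{1}{252x^{6}} - \cdots
```

valid for large x; smaller arguments are first shifted up via the recurrence ψ(x) = ψ(x + 1) − 1/x.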
@guitargeek
Thanks. But how did you make the decision to stop expanding the series after 6 terms? I don't see anything about this in the reference.

Also, since the formula is an approximation, which is very important to know, I'd still prefer if this critical info were commented inline in the code.

Another thing: since the …

Thank you very much again for contributing this important feature!
Hey @guitargeek, thanks for the overview!

You are totally right about the compounding errors. Trusting AD to differentiate through an asymptotic approximation is definitely a bad idea for higher-order derivatives. I would love to add a trigamma implementation to this PR or a subsequent one.

As for the 6-term cutoff: once we push the input to x >= 8, the 6th term evaluates to ~2.4e-11, which puts us against the limit of a 64-bit double. I reckon adding anything beyond that just burns CPU cycles for no real gain in accuracy. I could implement it past the 6th term if you would like me to!

As for the most important part, I'll get those inline comments updated and start working on the trigamma implementation.
force-pushed from fba7d30 to 30a2e9d
```cpp
CUDA_HOST_DEVICE void beta_pullback(T x, T y, U d_z, T* d_x, T* d_y) {
  T b = clad_beta_primal(x, y);
  T psi_xy = clad_digamma(x + y);
  if (d_x)
```
Can d_x and d_y be nullptr?
Yes. If you only request the derivative with respect to one of the arguments, Clad passes a nullptr for the inactive variable's adjoint.
The if checks are there to make sure we don't segfault when trying to write to them. I just followed the same pattern used by the other _pullback functions in this file to keep things consistent.
Oh, I see. Thanks. It seems that we are not properly adding these if-checks for many pre-existing built-in derivative functions.
guitargeek left a comment:
Thanks, looks good to me at this point!
force-pushed from e122444 to da650c6
force-pushed from da650c6 to 67b13b9
force-pushed from 460f514 to f850327
This function was introduced about 2 years ago as a quick fix to make `LnGamma` differentiable with Clad, but it was inconvenient because it depended on the external GSL library. Since vgvassilev/clad#1733, Clad natively implements a `digamma` function, so we can use this one now that we updated to Clad 2.3.
Added a digamma utility to BuiltInDerivatives.h. It evaluates the function using a standard asymptotic expansion and applies a recurrence shift for smaller values of x to keep the floating-point results accurate.
Implemented beta_pushforward and beta_pullback.