diff --git a/demo/test.c b/demo/test.c index 2fa6e08d..fad237b9 100644 --- a/demo/test.c +++ b/demo/test.c @@ -1930,6 +1930,25 @@ static int test_mp_root_n(void) EXPECT(mp_cmp(&r, &c) == MP_EQ); } } + /* 0^(1/x) = 0 with x != 0 is allowed, test */ + mp_set(&a, 0); + DO(mp_root_n(&a, 2, &c)); + EXPECT(mp_cmp_d(&c, 0) == MP_EQ); + + /* Not allowed: division by zero */ + mp_set(&a, 2); + EXPECT(mp_root_n(&a, 0, &c) == MP_VAL); + + /* root^base == input with small input and base */ + mp_set(&a, 4); + DO(mp_root_n(&a, 2, &c)); + EXPECT(mp_cmp_d(&c, 2) == MP_EQ); + + /* (root^base)^(1/(base + 1)) with small root */ + DO(mp_2expt(&a, 48)); + DO(mp_root_n(&a, 49, &c)); + EXPECT(mp_cmp_d(&c, 1) == MP_EQ); + mp_clear_multi(&a, &c, &r, NULL); return EXIT_SUCCESS; LBL_ERR: diff --git a/mp_root_n.c b/mp_root_n.c index d904df88..f9f3b891 100644 --- a/mp_root_n.c +++ b/mp_root_n.c @@ -18,6 +18,22 @@ mp_err mp_root_n(const mp_int *a, int b, mp_int *c) int ilog2; mp_err err; + + if (b == 0) { + mp_set(c, 0); + return MP_VAL; + } + + /* 0^(1/x) = 0 with x != 0 is allowed */ + if (mp_iszero(a)) { + mp_set(c, 0); + if (b != 0) { + return MP_OKAY; + } else { + return MP_VAL; + } + } + if (b < 0 || (unsigned)b > (unsigned)MP_DIGIT_MAX) { return MP_VAL; } @@ -109,7 +125,8 @@ mp_err mp_root_n(const mp_int *a, int b, mp_int *c) cmp = mp_cmp(&t2, &a_); if (cmp == MP_EQ) { err = MP_OKAY; - goto LBL_ERR; + /* On point, skip overshoot correction */ + goto LBL_SET; } if (cmp == MP_LT) { if ((err = mp_add_d(&t1, 1uL, &t1)) != MP_OKAY) goto LBL_ERR; @@ -127,6 +144,7 @@ mp_err mp_root_n(const mp_int *a, int b, mp_int *c) } } +LBL_SET: /* set the result */ mp_exch(&t1, c);