diff --git a/src/mcts.c b/src/mcts.c
index cb241fb..414219c 100644
--- a/src/mcts.c
+++ b/src/mcts.c
@@ -50,38 +50,95 @@ static fixed_point_t fixed_sqrt(fixed_point_t x)
     return s;
 }
 
-static fixed_point_t fixed_log(fixed_point_t v)
+#define LOG2_TABLE_SIZE 10
+/* Q0.32 approximations of log2(1 + i / 2^LOG2_TABLE_SIZE), set by init */
+static unsigned log2_table[1U << LOG2_TABLE_SIZE];
+
+/* Fill log2_table; has to run before fixed_log2() is first used. */
+static void log2_table_init(void)
+{
+    /*
+     * Q0.32 fixed-point representation of
+     * 2^{1/2}-1, 2^{1/4}-1, ..., 2^{2^{-32}}-1
+     */
+    const unsigned jump[32] = {
+        1779033704, 812638371, 388727752, 190154448, 94047537, 46769127,
+        23321248, 11644838, 5818478, 2908254, 1453881, 726879,
+        363424, 181708, 90853, 45426, 22713, 11357,
+        5678, 2839, 1420, 710, 355, 177,
+        89, 44, 22, 11, 6, 3,
+        1, 1};
+
+    for (unsigned i = 0; i < (1U << LOG2_TABLE_SIZE); ++i) {
+        /*
+         * Use binary search to find the largest log2(1+x) <= log2(1+i/1024)
+         */
+        u64 target = (u64) i << (64 - LOG2_TABLE_SIZE), now = 0;
+        unsigned log = 0;
+        for (unsigned j = 0; j < 32; ++j) {
+            /* (1+now) * (1+jump) = 1 + now + jump + now*jump */
+            u64 t = ((now + (1U << 31)) >> 32) * jump[j];
+            if (now + ((u64) jump[j] << 32) + t <= target) {
+                now += ((u64) jump[j] << 32) + t;
+                log |= 1U << (31 - j);
+            }
+        }
+        log2_table[i] = log;
+    }
+}
+
+/*
+ * Base-2 logarithm of the fixed-point value v (0 for v == 0 and v == 1.0).
+ *
+ * The literal 15/16 shifts below assume FIXED_SCALE_BITS == 16 and a 32-bit
+ * fixed_point_t. The integer part is taken from the position of the leading
+ * one bit; the fraction is interpolated linearly between two neighbouring
+ * log2_table entries. A negative result is returned with the sign flag set.
+ */
+static fixed_point_t fixed_log2(fixed_point_t v)
 {
     if (!v || v == (1U << FIXED_SCALE_BITS))
         return 0;
 
-    fixed_point_t numerator = (v - (1U << FIXED_SCALE_BITS));
-    int neg = 0;
-    if (GET_SIGN(numerator)) {
-        neg = 1;
-        numerator = CLR_SIGN(numerator);
-        numerator = (1U << 31) - numerator;
-    }
+    /* Normalise so the leading one bit of v sits at bit 31 */
+    int log2_v = 15 - __builtin_clz(v);
+    v <<= (15 - log2_v);
+    fixed_point_t int_part = (unsigned) log2_v << FIXED_SCALE_BITS;
 
-    fixed_point_t y = ((u64) numerator << FIXED_SCALE_BITS) /
-                      ((u64) v + (1U << FIXED_SCALE_BITS));
+    /* The LOG2_TABLE_SIZE bits below the leading one pick the table slot */
+    unsigned index = (v ^ (1U << 31)) >> (31 - LOG2_TABLE_SIZE);
+    unsigned lower = log2_table[index];
+    /* Past the last entry the bound is 1.0, which wraps to 0 in Q0.32 */
+    unsigned upper =
+        index == (1 << LOG2_TABLE_SIZE) - 1 ? 0 : log2_table[index + 1];
 
-    fixed_point_t ans = 0U;
-    for (unsigned i = 1; i < 20; i += 2) {
-        fixed_point_t z = (1U << FIXED_SCALE_BITS);
-        for (int j = 0; j < i; j++) {
-            z = ((u64) z * y) >> FIXED_SCALE_BITS;
-        }
-        z = ((u64) z << FIXED_SCALE_BITS) / (i << FIXED_SCALE_BITS);
+    unsigned offset = v & ((1U << (31 - LOG2_TABLE_SIZE)) - 1);
+    u64 frac_part =
+        lower +
+        (((u64) (upper - lower) * offset + (1U << (30 - LOG2_TABLE_SIZE))) >>
+         (31 - LOG2_TABLE_SIZE));
 
-        ans += z;
-    }
-    ans <<= 1;
-    ans = neg ? SET_SIGN(ans) : ans;
-    return ans;
+    /* Round the Q0.32 fraction to 16 fraction bits, add the integer part */
+    unsigned result = int_part + ((frac_part + (1U << 15)) >> 16);
+
+    /* Convert from 2's complement to sign-magnitude representation */
+    if (GET_SIGN(result))
+        result = SET_SIGN(-result);
+
+    return result;
 }
 
-#define EXPLORATION_FACTOR fixed_sqrt(1U << (FIXED_SCALE_BITS + 1))
+/* sqrt(ln 2) with 16 fraction bits; turns sqrt(log2 x) into sqrt(ln x) */
+#define SQRT_LOG_2 54562
+/*
+ * sqrt(2 * ln 2). The product is widened to u64 before the scaling shift:
+ * with 16 fraction bits fixed_sqrt(2.0) should be about 92682, and
+ * 92682 * 54562 does not fit in 32 bits.
+ */
+#define EXPLORATION_FACTOR \
+    ((fixed_point_t) (((u64) fixed_sqrt(1U << (FIXED_SCALE_BITS + 1)) * \
+                       SQRT_LOG_2) >> \
+                      FIXED_SCALE_BITS))
 
 static inline fixed_point_t uct_score(int n_total, int n_visits, u64 score)
 {
@@ -89,7 +146,7 @@ static inline fixed_point_t uct_score(int n_total, int n_visits, u64 score)
         return FIXED_MAX;
 
     fixed_point_t result = (fixed_point_t) (score / n_visits);
-    fixed_point_t log_val = fixed_log(
+    fixed_point_t log_val = fixed_log2(
         (n_total < 65536) ? (n_total << FIXED_SCALE_BITS) : FIXED_MAX);
     fixed_point_t tmp =
         ((u64) EXPLORATION_FACTOR * fixed_sqrt(log_val / n_visits)) >>
@@ -219,5 +276,6 @@ int mcts(uint32_t table, char player)
 void mcts_init(void)
 {
     xoro_init(&(mcts_obj.xoro_obj));
+    log2_table_init();
     mcts_obj.nr_active_nodes = 0;
 }