Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 59 additions & 24 deletions src/mcts.c
Original file line number Diff line number Diff line change
Expand Up @@ -50,46 +50,80 @@ static fixed_point_t fixed_sqrt(fixed_point_t x)
return s;
}

static fixed_point_t fixed_log(fixed_point_t v)
#define LOG2_TABLE_SIZE 10
static unsigned log2_table[1U << LOG2_TABLE_SIZE];
static void log2_table_init(void)
{
/*
* Q0.32 fixed-point representation of
* 2^{1/2}-1, 2^{1/4}-1, ..., 2^{2^{-32}}-1
*/
const unsigned jump[32] = {
1779033704, 812638371, 388727752, 190154448, 94047537, 46769127,
23321248, 11644838, 5818478, 2908254, 1453881, 726879,
363424, 181708, 90853, 45426, 22713, 11357,
5678, 2839, 1420, 710, 355, 177,
89, 44, 22, 11, 6, 3,
1, 1};

for (unsigned i = 0; i < (1U << LOG2_TABLE_SIZE); ++i) {
/*
* Use binary search to find the largest log2(1+x) <= log2(1+i/1024)
*/
u64 target = (u64) i << (64 - LOG2_TABLE_SIZE), now = 0;
unsigned log = 0;
for (unsigned j = 0; j < 32; ++j) {
/* (1+now) * (1+jump) = 1 + now + jump + now*jump */
u64 t = ((now + (1U << 31)) >> 32) * jump[j];
if (now + ((u64) jump[j] << 32) + t <= target) {
now += ((u64) jump[j] << 32) + t;
log |= 1U << (31 - j);
}
}
log2_table[i] = log;
}
}

static fixed_point_t fixed_log2(fixed_point_t v)
{
if (!v || v == (1U << FIXED_SCALE_BITS))
return 0;

fixed_point_t numerator = (v - (1U << FIXED_SCALE_BITS));
int neg = 0;
if (GET_SIGN(numerator)) {
neg = 1;
numerator = CLR_SIGN(numerator);
numerator = (1U << 31) - numerator;
}
int log2_v = 15 - __builtin_clz(v);
v <<= (15 - log2_v);
fixed_point_t int_part = (unsigned) log2_v << FIXED_SCALE_BITS;

fixed_point_t y = ((u64) numerator << FIXED_SCALE_BITS) /
((u64) v + (1U << FIXED_SCALE_BITS));
unsigned index = (v ^ (1U << 31)) >> (31 - LOG2_TABLE_SIZE);
unsigned lower = log2_table[index];
unsigned upper =
index == (1 << LOG2_TABLE_SIZE) - 1 ? 0 : log2_table[index + 1];

fixed_point_t ans = 0U;
for (unsigned i = 1; i < 20; i += 2) {
fixed_point_t z = (1U << FIXED_SCALE_BITS);
for (int j = 0; j < i; j++) {
z = ((u64) z * y) >> FIXED_SCALE_BITS;
}
z = ((u64) z << FIXED_SCALE_BITS) / (i << FIXED_SCALE_BITS);
unsigned offset = v & ((1U << (31 - LOG2_TABLE_SIZE)) - 1);
u64 frac_part =
lower +
(((u64) (upper - lower) * offset + (1U << (30 - LOG2_TABLE_SIZE))) >>
(31 - LOG2_TABLE_SIZE));

ans += z;
}
ans <<= 1;
ans = neg ? SET_SIGN(ans) : ans;
return ans;
unsigned result = int_part + ((frac_part + (1U << 15)) >> 16);

/* Convert from 2's complement to signed representation */
if (GET_SIGN(result))
result = SET_SIGN(-result);

return result;
}

#define EXPLORATION_FACTOR fixed_sqrt(1U << (FIXED_SCALE_BITS + 1))
#define SQRT_LOG_2 54562
#define EXPLORATION_FACTOR \
(fixed_sqrt(1U << (FIXED_SCALE_BITS + 1)) * SQRT_LOG_2 >> FIXED_SCALE_BITS)

static inline fixed_point_t uct_score(int n_total, int n_visits, u64 score)
{
if (n_visits == 0)
return FIXED_MAX;

fixed_point_t result = (fixed_point_t) (score / n_visits);
fixed_point_t log_val = fixed_log(
fixed_point_t log_val = fixed_log2(
(n_total < 65536) ? (n_total << FIXED_SCALE_BITS) : FIXED_MAX);
fixed_point_t tmp =
((u64) EXPLORATION_FACTOR * fixed_sqrt(log_val / n_visits)) >>
Expand Down Expand Up @@ -219,5 +253,6 @@ int mcts(uint32_t table, char player)
void mcts_init(void)
{
xoro_init(&(mcts_obj.xoro_obj));
log2_table_init();
mcts_obj.nr_active_nodes = 0;
}