Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
209 changes: 106 additions & 103 deletions src/eval_constants.hpp

Large diffs are not rendered by default.

107 changes: 105 additions & 2 deletions src/eval_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#include "util/types.hpp"
#include <cassert>
#include <cmath>
#include <cstring>
#include <iostream>
#include <limits>
Expand Down Expand Up @@ -39,7 +40,7 @@
}

[[nodiscard]] inline auto mg() const {
const auto mg = static_cast<u16>(m_score);

Check warning on line 43 in src/eval_types.hpp

View workflow job for this annotation

GitHub Actions / Linter / cpp-linter

src/eval_types.hpp:43:20 [readability-identifier-naming]

invalid case style for constant 'mg'

i16 v{};
std::memcpy(&v, &mg, sizeof(mg));
Expand All @@ -48,7 +49,7 @@
}

[[nodiscard]] inline auto eg() const {
const auto eg = static_cast<u16>(static_cast<u32>(m_score + 0x8000) >> 16);

Check warning on line 52 in src/eval_types.hpp

View workflow job for this annotation

GitHub Actions / Linter / cpp-linter

src/eval_types.hpp:52:20 [readability-identifier-naming]

invalid case style for constant 'eg'

i16 v{};
std::memcpy(&v, &eg, sizeof(eg));
Expand Down Expand Up @@ -95,7 +96,7 @@

// Phasing between two scores
template<i32 max>
Value phase(i32 alpha) const {

Check warning on line 99 in src/eval_types.hpp

View workflow job for this annotation

GitHub Actions / Linter / cpp-linter

src/eval_types.hpp:99:5 [modernize-use-nodiscard]

function 'phase' should be marked [[nodiscard]]
assert(0 <= alpha && alpha <= max);
return static_cast<Value>((mg() * alpha + eg() * (max - alpha)) / max);
}
Expand All @@ -107,7 +108,6 @@
};

using PParam = PScore;

#else

using Score = Autograd::ValueHandle;
Expand All @@ -131,11 +131,114 @@
#define PSCORE_ZERO Autograd::PairHandle::create(0, 0)

#else
// ... (non-tuning definitions) ...
// ... (non-tuning definitions) ...
#define S(a, b) PScore((a), (b))
#define CS(a, b) PScore((a), (b))
#define PPARAM_ZERO PScore(0, 0)
#define PSCORE_ZERO PScore(0, 0)
#endif


// TunableSigmoid: a * sigmoid((x + c) / b)
// a and c are tunable pairs (mg, eg), b is a constant scale parameter
// For inference: uses a lookup table with linear interpolation in the 95% range
// 95% range is approximately [-b*ln(39), b*ln(39)] = [-3.664b, 3.664b]

template<i32 B_SCALE = 1000>
#ifdef EVAL_TUNING
class TunableSigmoid {
private:
PParam m_a; // Scaling parameter
PParam m_c; // Offset parameter

static constexpr f64 B = static_cast<f64>(B_SCALE);

public:
TunableSigmoid(i32 a0, i32 a1, i32 c0, i32 c1) :
m_a(S(a0, a1)),
m_c(S(c0, c1)) {
}

PScore operator()(PScore x) const {
auto scaled = x / B;
auto shifted = scaled + (m_c / B);
auto sig = shifted.sigmoid();
return m_a * sig;
}

PParam a() const {
return m_a;
}
i32 b() const {
return B_SCALE;
}
PParam c() const {
return m_c;
}
};
#else
class TunableSigmoid {
private:
static constexpr i32 TABLE_SIZE = 256;
static constexpr i32 FP_SHIFT = 16;
static constexpr i32 FP_ONE = 1 << FP_SHIFT;
static constexpr f64 B = static_cast<f64>(B_SCALE);
static constexpr f64 LN_39 = 3.6635616461296463;

struct Table {
i32 range_min;
i32 range_max;
i32 range_span;
i32 scale_fp;
std::array<i16, TABLE_SIZE> values;
};

Table m_mg;
Table m_eg;

public:
TunableSigmoid(i32 a_mg, i32 a_eg, i32 c_mg, i32 c_eg) {
build_table(m_mg, a_mg, c_mg);
build_table(m_eg, a_eg, c_eg);
}

PScore operator()(PScore x) const {
return PScore(lookup(x.mg(), m_mg), lookup(x.eg(), m_eg));
}

private:
static void build_table(Table& tbl, i32 a, i32 c) {
const f64 bound = B * LN_39;
tbl.range_min = static_cast<i32>(-bound) - c;
tbl.range_max = static_cast<i32>(bound) - c;
tbl.range_span = tbl.range_max - tbl.range_min;

tbl.scale_fp =
static_cast<i32>((static_cast<i64>(TABLE_SIZE - 1) << FP_SHIFT) / tbl.range_span);

for (i32 i = 0; i < TABLE_SIZE; ++i) {
const f64 t = static_cast<f64>(i) / (TABLE_SIZE - 1);
const f64 x = tbl.range_min + t * tbl.range_span;
const f64 z = (x + c) / B;
const f64 sig = 1.0 / (1.0 + std::exp(-z));

tbl.values[i] = static_cast<i16>(std::lround(a * sig));
}
}

static i16 lookup(i16 x_val, const Table& tbl) {
const i32 x = std::clamp(static_cast<i32>(x_val), tbl.range_min, tbl.range_max);
const i64 idx_fp = static_cast<i64>(x - tbl.range_min) * tbl.scale_fp;

const i32 idx = static_cast<i32>(idx_fp >> FP_SHIFT);
const i32 frac = static_cast<i32>(idx_fp & (FP_ONE - 1));

const i32 v0 = tbl.values[idx];
const i32 v1 = tbl.values[std::min(idx + 1, TABLE_SIZE - 1)];

return static_cast<i16>(v0 + ((v1 - v0) * frac >> FP_SHIFT));
}
};
#endif

} // namespace Clockwork
13 changes: 12 additions & 1 deletion src/evaltune_main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ int main() {
Parameters current_parameter_values = Graph::get().get_all_parameter_values();

// Uncomment for zero tune: Overwrite them all with zeros.
current_parameter_values = Parameters::zeros(parameter_count);
current_parameter_values = Parameters::rand_init(parameter_count);

// The optimizer starts from the parameter values assigned above
AdamW optim(parameter_count, 10, 0.9, 0.999, 1e-8, 0.0);
Expand Down Expand Up @@ -405,6 +405,17 @@ int main() {
print_table("BLOCKED_SHELTER_STORM", BLOCKED_SHELTER_STORM);
print_2d_array("SHELTER_STORM", SHELTER_STORM);

// Prints `sigmoid` to stdout as a C++ definition that can be pasted back into the source:
//   inline TunableSigmoid<templ> name(
//       a.first, a.second, c.first, c.second
//   );
// Each component of a() and c() is rounded to the nearest integer. The closing line carries
// the terminating ';' so the emitted definition is a complete statement.
auto print_sigmoid = [](const std::string& name, const auto& sigmoid, const i32 templ) {
    PairHandle a_h = static_cast<PairHandle>(sigmoid.a());
    PairHandle c_h = static_cast<PairHandle>(sigmoid.c());
    std::cout << "inline TunableSigmoid<" << templ << "> " << name << "(\n"
              << "\t" << std::lround(a_h.first()) << ", " << std::lround(a_h.second())
              << ", " << std::lround(c_h.first()) << ", " << std::lround(c_h.second())
              << "\n"
              << ");\n";
};
print_sigmoid("KING_SAFETY_ACTIVATION", KING_SAFETY_ACTIVATION, 32);

#endif
const auto end = time::Clock::now();
std::cout << "// Epoch duration: " << time::cast<time::FloatSeconds>(end - start).count()
Expand Down
47 changes: 34 additions & 13 deletions src/evaluation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -222,10 +222,11 @@ PScore evaluate_pawn_push_threats(const Position& pos) {
}

template<Color color>
PScore evaluate_pieces(const Position& pos) {
constexpr Color opp = ~color;
PScore eval = PSCORE_ZERO;
Bitboard own_pawns = pos.bitboard_for(color, PieceType::Pawn);
std::pair<PScore, PScore> evaluate_pieces(const Position& pos) {
constexpr Color opp = ~color;
PScore eval = PSCORE_ZERO;
PScore king_safety_score = PSCORE_ZERO;
Bitboard own_pawns = pos.bitboard_for(color, PieceType::Pawn);
Bitboard blocked_pawns =
own_pawns & pos.board().get_occupied_bitboard().shift_relative(color, Direction::South);
constexpr Bitboard early_ranks = color == Color::White
Expand All @@ -237,11 +238,11 @@ PScore evaluate_pieces(const Position& pos) {
Bitboard opp_king_ring = king_ring_table[pos.king_sq(opp).raw];
for (PieceId id : pos.get_piece_mask(color, PieceType::Knight)) {
eval += KNIGHT_MOBILITY[pos.mobility_of(color, id, ~bb)];
eval += KNIGHT_KING_RING[pos.mobility_of(color, id, opp_king_ring)];
king_safety_score += KNIGHT_KING_RING[pos.mobility_of(color, id, opp_king_ring)];
}
for (PieceId id : pos.get_piece_mask(color, PieceType::Bishop)) {
eval += BISHOP_MOBILITY[pos.mobility_of(color, id, ~bb)];
eval += BISHOP_KING_RING[pos.mobility_of(color, id, opp_king_ring)];
king_safety_score += BISHOP_KING_RING[pos.mobility_of(color, id, opp_king_ring)];
Square sq = pos.piece_list_sq(color)[id];
eval += BISHOP_PAWNS[std::min(
static_cast<usize>(8),
Expand All @@ -255,7 +256,7 @@ PScore evaluate_pieces(const Position& pos) {
for (PieceId id : pos.get_piece_mask(color, PieceType::Rook)) {
eval += ROOK_MOBILITY[pos.mobility_of(color, id, ~bb)];
eval += ROOK_MOBILITY[pos.mobility_of(color, id, ~bb2)];
eval += ROOK_KING_RING[pos.mobility_of(color, id, opp_king_ring)];
king_safety_score += ROOK_KING_RING[pos.mobility_of(color, id, opp_king_ring)];
// Rook lineups
Bitboard rook_file = Bitboard::file_mask(pos.piece_list_sq(color)[id].file());
eval += ROOK_LINEUP
Expand All @@ -268,15 +269,15 @@ PScore evaluate_pieces(const Position& pos) {
for (PieceId id : pos.get_piece_mask(color, PieceType::Queen)) {
eval += QUEEN_MOBILITY[pos.mobility_of(color, id, ~bb)];
eval += QUEEN_MOBILITY[pos.mobility_of(color, id, ~bb2)];
eval += QUEEN_KING_RING[pos.mobility_of(color, id, opp_king_ring)];
king_safety_score += QUEEN_KING_RING[pos.mobility_of(color, id, opp_king_ring)];
}
eval += KING_MOBILITY[pos.mobility_of(color, PieceId::king(), ~bb)];

if (pos.piece_count(color, PieceType::Bishop) >= 2) {
eval += BISHOP_PAIR_VAL;
}

return eval;
return {eval, king_safety_score};
}

template<Color color>
Expand Down Expand Up @@ -406,6 +407,13 @@ PScore evaluate_space(const Position& pos) {
return eval;
}

// Passes an accumulated king-safety score through KING_SAFETY_ACTIVATION and returns the
// activated score. Neither `color` nor `pos` is used in the computation.
template<Color color>
PScore king_safety_activation(const Position& pos, PScore& king_safety_score) {
    return KING_SAFETY_ACTIVATION(king_safety_score);
}

Score evaluate_white_pov(const Position& pos, const PsqtState& psqt_state) {
const Color us = pos.active_color();
usize phase = pos.piece_count(Color::White, PieceType::Knight)
Expand All @@ -419,19 +427,32 @@ Score evaluate_white_pov(const Position& pos, const PsqtState& psqt_state) {
* (pos.piece_count(Color::White, PieceType::Queen)
+ pos.piece_count(Color::Black, PieceType::Queen));

phase = std::min<usize>(phase, 24);
phase = std::min<usize>(phase, 24);
PScore eval = psqt_state.score(); // Used for linear components

PScore eval = psqt_state.score();
eval += evaluate_pieces<Color::White>(pos) - evaluate_pieces<Color::Black>(pos);
// Pieces - get king safety scores directly
auto [white_piece_score, white_king_attack] = evaluate_pieces<Color::White>(pos);
auto [black_piece_score, black_king_attack] = evaluate_pieces<Color::Black>(pos);
eval += white_piece_score - black_piece_score;

// Other linear components
eval += evaluate_pawns<Color::White>(pos) - evaluate_pawns<Color::Black>(pos);
eval +=
evaluate_pawn_push_threats<Color::White>(pos) - evaluate_pawn_push_threats<Color::Black>(pos);
eval += evaluate_potential_checkers<Color::White>(pos)
- evaluate_potential_checkers<Color::Black>(pos);
eval += evaluate_threats<Color::White>(pos) - evaluate_threats<Color::Black>(pos);
eval += evaluate_king_safety<Color::White>(pos) - evaluate_king_safety<Color::Black>(pos);
eval += evaluate_space<Color::White>(pos) - evaluate_space<Color::Black>(pos);
eval += evaluate_outposts<Color::White>(pos) - evaluate_outposts<Color::Black>(pos);

// Nonlinear king safety components
PScore white_king_attack_total = white_king_attack + evaluate_king_safety<Color::Black>(pos);
PScore black_king_attack_total = black_king_attack + evaluate_king_safety<Color::White>(pos);

// Nonlinear adjustment
eval += king_safety_activation<Color::White>(pos, white_king_attack_total)
- king_safety_activation<Color::Black>(pos, black_king_attack_total);

eval += (us == Color::White) ? TEMPO_VAL : -TEMPO_VAL;
return static_cast<Score>(eval.phase<24>(static_cast<i32>(phase)));
};
Expand Down
77 changes: 77 additions & 0 deletions src/tuning/graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,53 @@ PairHandle Graph::record_pair_value(OpType op, PairHandle lhs, ValueHandle rhs)
return out;
}

// Records a unary operation on a pair: computes the forward value, stores it with a zero
// gradient, and appends the matching node to the tape.
//
// Only OpType::PairSigmoid computes anything; any other op stores a zero pair.
PairHandle Graph::record_pair_unary(OpType op, PairHandle input) {
    PairHandle result_handle = m_pairs.next_handle();
    f64x2      operand       = m_pairs.val(input.index);

    f64x2 forward = f64x2::zero();
    if (op == OpType::PairSigmoid) {
        // Logistic function 1 / (1 + exp(-x)), applied to each component independently.
        const auto logistic = [](f64 v) {
            return 1.0 / (1.0 + std::exp(-v));
        };
        forward = f64x2::make(logistic(operand.first()), logistic(operand.second()));
    }

    m_pairs.alloc(forward, f64x2::zero());
    m_tape.push_back(Node::make_scalar(op, result_handle.index, input.index, 0.0));

    return result_handle;
}

// Records a binary operation between two pairs: computes the forward value, stores it with
// a zero gradient, and appends the matching node to the tape.
//
// Only OpType::PairMulPair computes anything (via f64x2::mul); any other op stores a zero pair.
PairHandle Graph::record_pair_value(OpType op, PairHandle lhs, PairHandle rhs) {
    PairHandle result_handle = m_pairs.next_handle();
    f64x2      left          = m_pairs.val(lhs.index);
    f64x2      right         = m_pairs.val(rhs.index);

    f64x2 forward = (op == OpType::PairMulPair) ? f64x2::mul(left, right) : f64x2::zero();

    m_pairs.alloc(forward, f64x2::zero());
    m_tape.push_back(Node::make_binary(op, result_handle.index, lhs.index, rhs.index));

    return result_handle;
}

ValueHandle Graph::record_phase(PairHandle lhs, f64 alpha) {
ValueHandle out = m_values.next_handle();
f64x2 pair_val = m_pairs.val(lhs.index);
Expand Down Expand Up @@ -335,6 +382,25 @@ void Graph::backward() {
break;
}

case OpType::PairSigmoid: {
const f64x2 grad_out = pair_grads[out_idx];

// sigmoid output values already computed in forward pass
f64x2 sigmoid_out = pair_vals[out_idx];

f64 sig_mg = sigmoid_out.first();
f64 sig_eg = sigmoid_out.second();

f64 grad_mg = sig_mg * (1.0 - sig_mg);
f64 grad_eg = sig_eg * (1.0 - sig_eg);

f64x2 local_grad = f64x2::make(grad_mg, grad_eg);
f64x2 update = f64x2::mul(local_grad, grad_out);

pair_grads[node.lhs()] = f64x2::add(pair_grads[node.lhs()], update);
break;
}

case OpType::PairAdd: {
const f64x2 grad_out = pair_grads[out_idx];
pair_grads[node.lhs()] = f64x2::add(pair_grads[node.lhs()], grad_out);
Expand Down Expand Up @@ -411,6 +477,17 @@ void Graph::backward() {
grad_out.first() * recip.first() + grad_out.second() * recip.second();
break;
}
case OpType::PairMulPair: {
const f64x2 grad_out = pair_grads[out_idx];
f64x2 l = pair_vals[node.lhs()];
f64x2 r = pair_vals[node.rhs()];

f64x2 grad_lhs = f64x2::mul(grad_out, r);
f64x2 grad_rhs = f64x2::mul(grad_out, l);
pair_grads[node.lhs()] = f64x2::add(pair_grads[node.lhs()], grad_lhs);
pair_grads[node.rhs()] = f64x2::add(pair_grads[node.rhs()], grad_rhs);
break;
}

default:
unreachable();
Expand Down
2 changes: 2 additions & 0 deletions src/tuning/graph.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ class Graph {
PairHandle record_pair_op(OpType op, PairHandle lhs, PairHandle rhs);
PairHandle record_pair_scalar(OpType op, PairHandle input, f64 scalar);
PairHandle record_pair_value(OpType op, PairHandle pair, ValueHandle val);
PairHandle record_pair_value(OpType op, PairHandle lhs, PairHandle rhs);
PairHandle record_pair_unary(OpType op, PairHandle input);

ValueHandle record_phase(PairHandle input, f64 alpha);

Expand Down
Loading
Loading