diff --git a/tests/test1.cpp b/tests/test1.cpp index 3273588..0d03cdd 100644 --- a/tests/test1.cpp +++ b/tests/test1.cpp @@ -316,6 +316,36 @@ static void test_pow2_weighted_roll_distribution() } } +static void test_quadratic_weighted_roll_distribution() +{ + seed s(42); + const int high = 4; + const int total_rolls = 1000000; + std::vector counts(high + 1, 0); + + for (int i = 0; i < total_rolls; ++i) + { + int result = s.quadratic_weighted_roll(high); + assert(result >= 0 && result <= high); + counts[result]++; + } + + uint64_t total_possibilities = (high + 1) * (high + 1); + + for (int i = 0; i <= high; ++i) + { + uint64_t weight = 2 * (high - i) + 1; + double expected_prob = (double)weight / total_possibilities; + int expected_count = (int)(total_rolls * expected_prob); + + int margin = expected_count / 10; // 10% margin of error + if (margin < 50) margin = 50; // minimum margin + + assert(counts[i] >= expected_count - margin); + assert(counts[i] <= expected_count + margin); + } +} + static void test_dice_guards() { seed s(123); @@ -372,6 +402,7 @@ int main(int argc, char **argv) linear_roll_table_test(); edge_cases(); test_pow2_weighted_roll_distribution(); + test_quadratic_weighted_roll_distribution(); test_dice_guards(); int j = 0;