diff --git a/src/main/java/BitOps.java b/src/main/java/BitOps.java index ad8b0e5..8b78eeb 100644 --- a/src/main/java/BitOps.java +++ b/src/main/java/BitOps.java @@ -1,6 +1,6 @@ package org.lichess.compression; -class BitOps { +public class BitOps { static int[] getBitMasks() { int[] mask = new int[32]; for (int i = 0; i < 32; i++) { @@ -8,4 +8,8 @@ static int[] getBitMasks() { } return mask; } + + public static int moduloPowerOfTwo(int dividend, int exponent) { + return dividend & ((1 << exponent) - 1); + } } diff --git a/src/main/java/game/Encoder.java b/src/main/java/game/Encoder.java index d68580f..bf9c0ae 100644 --- a/src/main/java/game/Encoder.java +++ b/src/main/java/game/Encoder.java @@ -1,19 +1,14 @@ package org.lichess.compression.game; -import java.util.Arrays; -import java.util.HashMap; -import java.util.Map; -import java.util.HashSet; -import java.util.Set; -import java.util.regex.Pattern; -import java.util.regex.Matcher; - -import java.nio.ByteBuffer; - import org.lichess.compression.BitReader; import org.lichess.compression.BitWriter; +import org.lichess.compression.game.codec.Rans; + +import java.util.Arrays; public class Encoder { + private final Rans rans = new Rans(); + private static final ThreadLocal moveList = new ThreadLocal() { @Override protected MoveList initialValue() { @@ -21,101 +16,27 @@ protected MoveList initialValue() { } }; - private static Pattern SAN_PATTERN = Pattern.compile( - "([NBKRQ])?([a-h])?([1-8])?x?([a-h][1-8])(?:=([NBRQK]))?[\\+#]?"); - - private static Role charToRole(char c) { - switch (c) { - case 'N': return Role.KNIGHT; - case 'B': return Role.BISHOP; - case 'R': return Role.ROOK; - case 'Q': return Role.QUEEN; - case 'K': return Role.KING; - default: throw new IllegalArgumentException(); - } - } - - public static byte[] encode(String pgnMoves[]) { + public EncodeResult encode(String[] pgnMoves) { BitWriter writer = new BitWriter(); - - Board board = new Board(); - MoveList legals = moveList.get(); - - for (String pgnMove: pgnMoves) { - // Parse SAN. - Role role = null, promotion = null; - long from = Bitboard.ALL; - int to; - - if (pgnMove.startsWith("O-O-O")) { - role = Role.KING; - from = board.kings; - to = Bitboard.lsb(board.rooks & Bitboard.RANKS[board.turn ? 0 : 7]); - } else if (pgnMove.startsWith("O-O")) { - role = Role.KING; - from = board.kings; - to = Bitboard.msb(board.rooks & Bitboard.RANKS[board.turn ? 0 : 7]); - } else { - Matcher matcher = SAN_PATTERN.matcher(pgnMove); - if (!matcher.matches()) return null; - - String roleStr = matcher.group(1); - role = roleStr == null ? Role.PAWN : charToRole(roleStr.charAt(0)); - - if (matcher.group(2) != null) from &= Bitboard.FILES[matcher.group(2).charAt(0) - 'a']; - if (matcher.group(3) != null) from &= Bitboard.RANKS[matcher.group(3).charAt(0) - '1']; - - to = Square.square(matcher.group(4).charAt(0) - 'a', matcher.group(4).charAt(1) - '1'); - - if (matcher.group(5) != null) { - promotion = charToRole(matcher.group(5).charAt(0)); - } - } - - // Find index in legal moves. - board.legalMoves(legals); - legals.sort(); - - boolean foundMatch = false; - int size = legals.size(); - - for (int i = 0; i < size; i++) { - Move legal = legals.get(i); - if (legal.role == role && legal.to == to && legal.promotion == promotion && Bitboard.contains(from, legal.from)) { - if (!foundMatch) { - // Encode and play. - Huffman.write(i, writer); - board.play(legal); - foundMatch = true; - } - else return null; - } - } - - if (!foundMatch) return null; + int[] moveIndexes = GameToMoveIndexesConverter.convert(pgnMoves); + if (moveIndexes == null) { + return null; } - - return writer.toArray(); - } - - public static class DecodeResult { - public final String pgnMoves[]; - public final Board board; - public final int halfMoveClock; - public final byte positionHashes[]; - public final String lastUci; - - public DecodeResult(String pgnMoves[], Board board, int halfMoveClock, byte positionHashes[], String lastUci) { - this.pgnMoves = pgnMoves; - this.board = board; - this.halfMoveClock = halfMoveClock; - this.positionHashes = positionHashes; - this.lastUci = lastUci; + rans.resetEncoder(); + for (int i = moveIndexes.length - 1; i >= 0; i--) { + int moveIndex = moveIndexes[i]; + rans.write(moveIndex, writer); } + byte[] encoded = writer.toArray(); + byte[] encodedReversed = new byte[encoded.length]; + for (int i = 0; i < encoded.length; i++) { + encodedReversed[i] = encoded[encoded.length - (i + 1)]; + } + return new EncodeResult(encodedReversed, rans.getState()); } - public static DecodeResult decode(byte input[], int plies) { - BitReader reader = new BitReader(input); + public DecodeResult decode(EncodeResult input, int plies) { + BitReader reader = new BitReader(input.code); String output[] = new String[plies]; @@ -131,6 +52,8 @@ public static DecodeResult decode(byte input[], int plies) { byte positionHashes[] = new byte[3 * (plies + 1)]; setHash(positionHashes, -1, board.zobristHash()); + rans.initializeDecoder(input.state); + for (int i = 0; i <= plies; i++) { if (0 < i || i < plies) board.legalMoves(legals); @@ -142,7 +65,7 @@ public static DecodeResult decode(byte input[], int plies) { // Decode and play next move. if (i < plies) { legals.sort(); - Move move = legals.get(Huffman.read(reader)); + Move move = legals.get(rans.read(reader)); output[i] = san(move, legals); board.play(move); @@ -162,6 +85,32 @@ public static DecodeResult decode(byte input[], int plies) { lastUci); } + public static class EncodeResult { + public final byte[] code; + public final int state; + + public EncodeResult(byte[] code, int state) { + this.code = code; + this.state = state; + } + } + + public static class DecodeResult { + public final String[] pgnMoves; + public final Board board; + public final int halfMoveClock; + public final byte[] positionHashes; + public final String lastUci; + + public DecodeResult(String[] pgnMoves, Board board, int halfMoveClock, byte[] positionHashes, String lastUci) { + this.pgnMoves = pgnMoves; + this.board = board; + this.halfMoveClock = halfMoveClock; + this.positionHashes = positionHashes; + this.lastUci = lastUci; + } + } + private static String san(Move move, MoveList legals) { switch (move.type) { case Move.NORMAL: diff --git a/src/main/java/game/GameToMoveIndexesConverter.java b/src/main/java/game/GameToMoveIndexesConverter.java new file mode 100644 index 0000000..274a81b --- /dev/null +++ b/src/main/java/game/GameToMoveIndexesConverter.java @@ -0,0 +1,109 @@ +package org.lichess.compression.game; + +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +public class GameToMoveIndexesConverter +{ + private static final ThreadLocal moveList = new ThreadLocal() { + @Override + protected MoveList initialValue() { + return new MoveList(); + } + }; + + private static Pattern SAN_PATTERN = Pattern.compile( + "([NBKRQ])?([a-h])?([1-8])?x?([a-h][1-8])(?:=([NBRQK]))?[\\+#]?"); + + private static Role charToRole(char c) { + switch (c) { + case 'N': return Role.KNIGHT; + case 'B': return Role.BISHOP; + case 'R': return Role.ROOK; + case 'Q': return Role.QUEEN; + case 'K': return Role.KING; + default: throw new IllegalArgumentException(); + } + } + + public static int[] convert(String[] pgnMoves) { + Board board = new Board(); + MoveList legals = moveList.get(); + + int[] moveIndexes = new int[pgnMoves.length]; + + for (int ply = 0; ply < pgnMoves.length; ply++) { + String pgnMove = pgnMoves[ply]; + Lan current = parseSan(pgnMove, board); + + // Find index in legal moves. + board.legalMoves(legals); + legals.sort(); + + boolean foundMatch = false; + int size = legals.size(); + + for (int i = 0; i < size; i++) { + Move legal = legals.get(i); + if (legal.role == current.role && legal.to == current.to && legal.promotion == current.promotion && Bitboard.contains(current.from, legal.from)) { + if (!foundMatch) { + // Save and play. + moveIndexes[ply] = i; + board.play(legal); + foundMatch = true; + } else return null; + } + } + + if (!foundMatch) return null; + } + return moveIndexes; + } + + private static Lan parseSan(String pgnMove, Board board) { + Role role; + Role promotion = null; + long from = Bitboard.ALL; + int to; + + if (pgnMove.startsWith("O-O-O")) { + role = Role.KING; + from = board.kings; + to = Bitboard.lsb(board.rooks & Bitboard.RANKS[board.turn ? 0 : 7]); + } else if (pgnMove.startsWith("O-O")) { + role = Role.KING; + from = board.kings; + to = Bitboard.msb(board.rooks & Bitboard.RANKS[board.turn ? 0 : 7]); + } else { + Matcher matcher = SAN_PATTERN.matcher(pgnMove); + if (!matcher.matches()) return null; + + String roleStr = matcher.group(1); + role = roleStr == null ? Role.PAWN : charToRole(roleStr.charAt(0)); + + if (matcher.group(2) != null) from &= Bitboard.FILES[matcher.group(2).charAt(0) - 'a']; + if (matcher.group(3) != null) from &= Bitboard.RANKS[matcher.group(3).charAt(0) - '1']; + + to = Square.square(matcher.group(4).charAt(0) - 'a', matcher.group(4).charAt(1) - '1'); + + if (matcher.group(5) != null) { + promotion = charToRole(matcher.group(5).charAt(0)); + } + } + return new Lan(role, promotion, from, to); + } + + public static class Lan { + public final Role role; + public final Role promotion; + public long from; + public int to; + + public Lan(Role role, Role promotion, long from, int to) { + this.role = role; + this.promotion = promotion; + this.from = from; + this.to = to; + } + } +} diff --git a/src/main/java/game/Huffman.java b/src/main/java/game/Huffman.java deleted file mode 100644 index f1b4426..0000000 --- a/src/main/java/game/Huffman.java +++ /dev/null @@ -1,334 +0,0 @@ -package org.lichess.compression.game; - -import org.lichess.compression.BitReader; -import org.lichess.compression.BitWriter; - -class Huffman { - public static void write(int value, BitWriter writer) { - Symbol symbol = CODES[value]; - writer.writeBits(symbol.code, symbol.bits); - } - - public static int read(BitReader reader) { - Node node = ROOT; - while (node.zero != null && node.one != null) { - int bit = reader.readBits(1); - if (bit == 0) node = node.zero; - else node = node.one; - } - return node.leaf; - } - - private static class Symbol { - public final int code; - public final int bits; - - public Symbol(int code, int bits) { - this.code = code; - this.bits = bits; - } - } - - private static class Node { - public final Node zero; - public final Node one; - public final int leaf; - - public Node(int leaf) { - this.zero = null; - this.one = null; - this.leaf = leaf; - } - - public Node(Node zero, Node one) { - this.zero = zero; - this.one = one; - this.leaf = -1; - } - } - - private static Node buildTree(int code, int bits) { - assert bits <= 32; - - for (int i = 0; i <= 0xff; i++) { - if (CODES[i].code == code && CODES[i].bits == bits) { - return new Node(i); - } - } - - return new Node( - buildTree(code << 1, bits + 1), - buildTree((code << 1) | 1, bits + 1)); - } - - // Huffman code for indexes in the legal move list. Precomputed based on - // actual frequency in 16,232,215 rated games. - // - // This is based on a maximum of 256 legal moves per position, but the - // highest indexes did not actually occur. They were manually assigned a - // frequency of 1 and ordered. - // - // On the training corpus this achieves: - // 37.03 bytes per game - // 0.551 bytes per move - private static final Symbol CODES[] = { - new Symbol(0b00, 2), // 0: 225883932 (20.71%) - new Symbol(0b100, 3), // 1: 134956126 (12.37%) - new Symbol(0b1101, 4), // 2: 89041269 (8.16%) - new Symbol(0b1010, 4), // 3: 69386238 (6.36%) - new Symbol(0b0101, 4), // 4: 57040790 (5.23%) - new Symbol(0b11101, 5), // 5: 44974559 (4.12%) - new Symbol(0b10111, 5), // 6: 36547155 (3.35%) - new Symbol(0b01110, 5), // 7: 31624920 (2.90%) - new Symbol(0b01100, 5), // 8: 28432772 (2.61%) - new Symbol(0b01000, 5), // 9: 26540493 (2.43%) - new Symbol(0b111101, 6), // 10: 24484873 (2.24%) - new Symbol(0b111001, 6), // 11: 23058034 (2.11%) - new Symbol(0b111100, 6), // 12: 23535272 (2.16%) - new Symbol(0b110011, 6), // 13: 20482457 (1.88%) - new Symbol(0b110010, 6), // 14: 20450172 (1.87%) - new Symbol(0b110000, 6), // 15: 18316057 (1.68%) - new Symbol(0b101101, 6), // 16: 17214833 (1.58%) - new Symbol(0b101100, 6), // 17: 16964761 (1.56%) - new Symbol(0b011111, 6), // 18: 16530028 (1.52%) - new Symbol(0b011011, 6), // 19: 15369510 (1.41%) - new Symbol(0b010011, 6), // 20: 14178440 (1.30%) - new Symbol(0b011010, 6), // 21: 14275714 (1.31%) - new Symbol(0b1111111, 7), // 22: 13353306 (1.22%) - new Symbol(0b1111101, 7), // 23: 12829602 (1.18%) - new Symbol(0b1111110, 7), // 24: 13102592 (1.20%) - new Symbol(0b1111100, 7), // 25: 11932647 (1.09%) - new Symbol(0b1110000, 7), // 26: 10608657 (0.97%) - new Symbol(0b1100011, 7), // 27: 10142459 (0.93%) - new Symbol(0b0111101, 7), // 28: 8294594 (0.76%) - new Symbol(0b0100101, 7), // 29: 7337490 (0.67%) - new Symbol(0b0100100, 7), // 30: 6337744 (0.58%) - new Symbol(0b11100010, 8), // 31: 5380717 (0.49%) - new Symbol(0b11000101, 8), // 32: 4560556 (0.42%) - new Symbol(0b01111001, 8), // 33: 3913313 (0.36%) - new Symbol(0b111000111, 9), // 34: 3038767 (0.28%) - new Symbol(0b110001001, 9), // 35: 2480514 (0.23%) - new Symbol(0b011110001, 9), // 36: 1951026 (0.18%) - new Symbol(0b011110000, 9), // 37: 1521451 (0.14%) - new Symbol(0b1110001100, 10), // 38: 1183252 (0.11%) - new Symbol(0b1100010000, 10), // 39: 938708 (0.09%) - new Symbol(0b11100011010, 11), // 40: 673339 (0.06%) - new Symbol(0b11000100010, 11), // 41: 513153 (0.05%) - new Symbol(0b111000110110, 12), // 42: 377299 (0.03%) - new Symbol(0b110001000110, 12), // 43: 276996 (0.03%) - new Symbol(0b1110001101110, 13), // 44: 199682 (0.02%) - new Symbol(0b1100010001110, 13), // 45: 144602 (0.01%) - new Symbol(0b11100011011110, 14), // 46: 103313 (0.01%) - new Symbol(0b11000100011110, 14), // 47: 73046 (0.01%) - new Symbol(0b111000110111110, 15), // 48: 52339 (0.00%) - new Symbol(0b110001000111110, 15), // 49: 36779 (0.00%) - new Symbol(0b1110001101111110, 16), // 50: 26341 (0.00%) - new Symbol(0b1100010001111110, 16), // 51: 18719 (0.00%) - new Symbol(0b11000100011111111, 17), // 52: 13225 (0.00%) - new Symbol(0b111000110111111111, 18), // 53: 9392 (0.00%) - new Symbol(0b111000110111111101, 18), // 54: 6945 (0.00%) - new Symbol(0b110001000111111100, 18), // 55: 4893 (0.00%) - new Symbol(0b1110001101111111100, 19), // 56: 3698 (0.00%) - new Symbol(0b1100010001111111011, 19), // 57: 2763 (0.00%) - new Symbol(0b11100011011111111011, 20), // 58: 2114 (0.00%) - new Symbol(0b11100011011111110010, 20), // 59: 1631 (0.00%) - new Symbol(0b11100011011111110000, 20), // 60: 1380 (0.00%) - new Symbol(0b111000110111111110101, 21), // 61: 1090 (0.00%) - new Symbol(0b111000110111111100110, 21), // 62: 887 (0.00%) - new Symbol(0b111000110111111100010, 21), // 63: 715 (0.00%) - new Symbol(0b110001000111111101001, 21), // 64: 590 (0.00%) - new Symbol(0b110001000111111101000, 21), // 65: 549 (0.00%) - new Symbol(0b1110001101111111101000, 22), // 66: 477 (0.00%) - new Symbol(0b1110001101111111000110, 22), // 67: 388 (0.00%) - new Symbol(0b1100010001111111010111, 22), // 68: 351 (0.00%) - new Symbol(0b1100010001111111010101, 22), // 69: 319 (0.00%) - new Symbol(0b11100011011111111010011, 23), // 70: 262 (0.00%) - new Symbol(0b11100011011111110011110, 23), // 71: 236 (0.00%) - new Symbol(0b11100011011111110001110, 23), // 72: 200 (0.00%) - new Symbol(0b11100011011111110001111, 23), // 73: 210 (0.00%) - new Symbol(0b11000100011111110101100, 23), // 74: 153 (0.00%) - new Symbol(0b111000110111111100111011, 24), // 75: 117 (0.00%) - new Symbol(0b111000110111111110100100, 24), // 76: 121 (0.00%) - new Symbol(0b111000110111111100111111, 24), // 77: 121 (0.00%) - new Symbol(0b111000110111111100111010, 24), // 78: 115 (0.00%) - new Symbol(0b110001000111111101011011, 24), // 79: 95 (0.00%) - new Symbol(0b110001000111111101010011, 24), // 80: 75 (0.00%) - new Symbol(0b110001000111111101010001, 24), // 81: 67 (0.00%) - new Symbol(0b1110001101111111001110011, 25), // 82: 55 (0.00%) - new Symbol(0b1110001101111111001110001, 25), // 83: 50 (0.00%) - new Symbol(0b1110001101111111001110010, 25), // 84: 55 (0.00%) - new Symbol(0b1100010001111111010100101, 25), // 85: 33 (0.00%) - new Symbol(0b1100010001111111010110100, 25), // 86: 33 (0.00%) - new Symbol(0b1100010001111111010100001, 25), // 87: 30 (0.00%) - new Symbol(0b11100011011111110011111011, 26), // 88: 32 (0.00%) - new Symbol(0b11100011011111110011111001, 26), // 89: 28 (0.00%) - new Symbol(0b11100011011111110011111010, 26), // 90: 29 (0.00%) - new Symbol(0b11100011011111110011111000, 26), // 91: 27 (0.00%) - new Symbol(0b11000100011111110101101011, 26), // 92: 21 (0.00%) - new Symbol(0b111000110111111110100101111, 27), // 93: 15 (0.00%) - new Symbol(0b110001000111111101011010100, 27), // 94: 9 (0.00%) - new Symbol(0b110001000111111101011010101, 27), // 95: 10 (0.00%) - new Symbol(0b111000110111111100111000010, 27), // 96: 12 (0.00%) - new Symbol(0b111000110111111100111000011, 27), // 97: 12 (0.00%) - new Symbol(0b110001000111111101010010011, 27), // 98: 8 (0.00%) - new Symbol(0b1110001101111111101001010011, 28), // 99: 7 (0.00%) - new Symbol(0b1100010001111111010100100101, 28), // 100: 2 (0.00%) - new Symbol(0b1110001101111111001110000011, 28), // 101: 4 (0.00%) - new Symbol(0b1110001101111111001110000010, 28), // 102: 5 (0.00%) - new Symbol(0b1110001101111111001110000000, 28), // 103: 5 (0.00%) - new Symbol(0b11100011011111110011100000010, 29), // 104 - new Symbol(0b11000100011111110101000001001, 29), // 105: 5 (0.00%) - new Symbol(0b11100011011111110011100000011, 29), // 106: 1 (0.00%) - new Symbol(0b11000100011111110101000001000, 29), // 107: 1 (0.00%) - new Symbol(0b11000100011111110101000000011, 29), // 108 - new Symbol(0b110001000111111101010000011110, 30), // 109: 1 (0.00%) - new Symbol(0b111000110111111110100101100110, 30), // 110: 2 (0.00%) - new Symbol(0b111000110111111110100101010111, 30), // 111: 1 (0.00%) - new Symbol(0b110001000111111101010000001101, 30), // 112: 1 (0.00%) - new Symbol(0b111000110111111110100101100010, 30), // 113 - new Symbol(0b110001000111111101010000001000, 30), // 114 - new Symbol(0b110001000111111101010000000101, 30), // 115: 1 (0.00%) - new Symbol(0b110001000111111101010000000000, 30), // 116 - new Symbol(0b110001000111111101010000001010, 30), // 117 - new Symbol(0b110001000111111101010010001101, 30), // 118 - new Symbol(0b110001000111111101010010010011, 30), // 119 - new Symbol(0b110001000111111101010010010010, 30), // 120 - new Symbol(0b110001000111111101010010010001, 30), // 121 - new Symbol(0b110001000111111101010010010000, 30), // 122 - new Symbol(0b110001000111111101010010001011, 30), // 123 - new Symbol(0b110001000111111101010010001010, 30), // 124 - new Symbol(0b110001000111111101010010001001, 30), // 125 - new Symbol(0b110001000111111101010010001000, 30), // 126 - new Symbol(0b110001000111111101010010000111, 30), // 127 - new Symbol(0b110001000111111101010010000110, 30), // 128 - new Symbol(0b110001000111111101010010000011, 30), // 129 - new Symbol(0b110001000111111101010010000010, 30), // 130 - new Symbol(0b110001000111111101010000011011, 30), // 131 - new Symbol(0b110001000111111101010000011010, 30), // 132 - new Symbol(0b110001000111111101010000011001, 30), // 133 - new Symbol(0b110001000111111101010000011000, 30), // 134 - new Symbol(0b110001000111111101010000010101, 30), // 135 - new Symbol(0b110001000111111101010000010100, 30), // 136 - new Symbol(0b110001000111111101010010000101, 30), // 137 - new Symbol(0b110001000111111101010010000100, 30), // 138 - new Symbol(0b110001000111111101010000011111, 30), // 139 - new Symbol(0b110001000111111101010000011101, 30), // 140 - new Symbol(0b110001000111111101010000011100, 30), // 141 - new Symbol(0b110001000111111101010010000001, 30), // 142 - new Symbol(0b110001000111111101010010000000, 30), // 143 - new Symbol(0b110001000111111101010000001111, 30), // 144 - new Symbol(0b110001000111111101010000001110, 30), // 145 - new Symbol(0b110001000111111101010000001100, 30), // 146 - new Symbol(0b110001000111111101010000010111, 30), // 147 - new Symbol(0b110001000111111101010000010110, 30), // 148 - new Symbol(0b110001000111111101010000001001, 30), // 149 - new Symbol(0b110001000111111101010000000100, 30), // 150 - new Symbol(0b110001000111111101010000000011, 30), // 151 - new Symbol(0b110001000111111101010000000010, 30), // 152 - new Symbol(0b110001000111111101010000000001, 30), // 153 - new Symbol(0b110001000111111101010000001011, 30), // 154 - new Symbol(0b110001000111111101010010001111, 30), // 155 - new Symbol(0b110001000111111101010010001110, 30), // 156 - new Symbol(0b110001000111111101010010001100, 30), // 157 - new Symbol(0b1110001101111111101001010111101, 31), // 158 - new Symbol(0b1110001101111111101001010111111, 31), // 159 - new Symbol(0b1110001101111111101001010100010, 31), // 160 - new Symbol(0b1110001101111111101001011011111, 31), // 161 - new Symbol(0b1110001101111111101001010100100, 31), // 162 - new Symbol(0b1110001101111111101001010111001, 31), // 163 - new Symbol(0b1110001101111111101001011011010, 31), // 164 - new Symbol(0b1110001101111111101001011010010, 31), // 165 - new Symbol(0b1110001101111111101001011010000, 31), // 166 - new Symbol(0b1110001101111111101001010111010, 31), // 167 - new Symbol(0b1110001101111111101001010001011, 31), // 168 - new Symbol(0b1110001101111111101001010001010, 31), // 169 - new Symbol(0b1110001101111111101001010001001, 31), // 170 - new Symbol(0b1110001101111111101001010001000, 31), // 171 - new Symbol(0b1110001101111111101001010000111, 31), // 172 - new Symbol(0b1110001101111111101001010000110, 31), // 173 - new Symbol(0b1110001101111111101001010000101, 31), // 174 - new Symbol(0b1110001101111111101001010000100, 31), // 175 - new Symbol(0b1110001101111111101001011010111, 31), // 176 - new Symbol(0b1110001101111111101001011010110, 31), // 177 - new Symbol(0b1110001101111111101001011010101, 31), // 178 - new Symbol(0b1110001101111111101001011010100, 31), // 179 - new Symbol(0b1110001101111111101001010110111, 31), // 180 - new Symbol(0b1110001101111111101001010110110, 31), // 181 - new Symbol(0b1110001101111111101001010010101, 31), // 182 - new Symbol(0b1110001101111111101001010010100, 31), // 183 - new Symbol(0b1110001101111111101001010110101, 31), // 184 - new Symbol(0b1110001101111111101001010110100, 31), // 185 - new Symbol(0b1110001101111111101001010010111, 31), // 186 - new Symbol(0b1110001101111111101001010010110, 31), // 187 - new Symbol(0b1110001101111111101001010110001, 31), // 188 - new Symbol(0b1110001101111111101001010110000, 31), // 189 - new Symbol(0b1110001101111111101001010010011, 31), // 190 - new Symbol(0b1110001101111111101001010010010, 31), // 191 - new Symbol(0b1110001101111111101001011101101, 31), // 192 - new Symbol(0b1110001101111111101001011101100, 31), // 193 - new Symbol(0b1110001101111111101001011101011, 31), // 194 - new Symbol(0b1110001101111111101001011101010, 31), // 195 - new Symbol(0b1110001101111111101001011100111, 31), // 196 - new Symbol(0b1110001101111111101001011100110, 31), // 197 - new Symbol(0b1110001101111111101001010010001, 31), // 198 - new Symbol(0b1110001101111111101001010010000, 31), // 199 - new Symbol(0b1110001101111111101001011100011, 31), // 200 - new Symbol(0b1110001101111111101001011100010, 31), // 201 - new Symbol(0b1110001101111111101001011100001, 31), // 202 - new Symbol(0b1110001101111111101001011100000, 31), // 203 - new Symbol(0b1110001101111111101001011101001, 31), // 204 - new Symbol(0b1110001101111111101001011101000, 31), // 205 - new Symbol(0b1110001101111111101001010001111, 31), // 206 - new Symbol(0b1110001101111111101001010001110, 31), // 207 - new Symbol(0b1110001101111111101001010000011, 31), // 208 - new Symbol(0b1110001101111111101001010000010, 31), // 209 - new Symbol(0b1110001101111111101001010001101, 31), // 210 - new Symbol(0b1110001101111111101001010001100, 31), // 211 - new Symbol(0b1110001101111111101001011001111, 31), // 212 - new Symbol(0b1110001101111111101001011001110, 31), // 213 - new Symbol(0b1110001101111111101001010000001, 31), // 214 - new Symbol(0b1110001101111111101001010000000, 31), // 215 - new Symbol(0b1110001101111111101001011011001, 31), // 216 - new Symbol(0b1110001101111111101001011011000, 31), // 217 - new Symbol(0b1110001101111111101001011100101, 31), // 218 - new Symbol(0b1110001101111111101001011100100, 31), // 219 - new Symbol(0b1110001101111111101001010101101, 31), // 220 - new Symbol(0b1110001101111111101001010101100, 31), // 221 - new Symbol(0b1110001101111111101001010110011, 31), // 222 - new Symbol(0b1110001101111111101001010110010, 31), // 223 - new Symbol(0b1110001101111111101001010101001, 31), // 224 - new Symbol(0b1110001101111111101001010101000, 31), // 225 - new Symbol(0b1110001101111111101001011101111, 31), // 226 - new Symbol(0b1110001101111111101001011101110, 31), // 227 - new Symbol(0b1110001101111111101001011001011, 31), // 228 - new Symbol(0b1110001101111111101001011001010, 31), // 229 - new Symbol(0b1110001101111111101001011000011, 31), // 230 - new Symbol(0b1110001101111111101001011000010, 31), // 231 - new Symbol(0b1110001101111111101001010101011, 31), // 232 - new Symbol(0b1110001101111111101001010101010, 31), // 233 - new Symbol(0b1110001101111111101001011001001, 31), // 234 - new Symbol(0b1110001101111111101001011001000, 31), // 235 - new Symbol(0b1110001101111111101001011000111, 31), // 236 - new Symbol(0b1110001101111111101001011000110, 31), // 237 - new Symbol(0b1110001101111111101001011000001, 31), // 238 - new Symbol(0b1110001101111111101001011000000, 31), // 239 - new Symbol(0b1110001101111111101001010111100, 31), // 240 - new Symbol(0b1110001101111111101001010100111, 31), // 241 - new Symbol(0b1110001101111111101001010100110, 31), // 242 - new Symbol(0b1110001101111111101001010111110, 31), // 243 - new Symbol(0b1110001101111111101001010100011, 31), // 244 - new Symbol(0b1110001101111111101001010100001, 31), // 245 - new Symbol(0b1110001101111111101001010100000, 31), // 246 - new Symbol(0b1110001101111111101001011011110, 31), // 247 - new Symbol(0b1110001101111111101001010100101, 31), // 248 - new Symbol(0b1110001101111111101001011011101, 31), // 249 - new Symbol(0b1110001101111111101001011011100, 31), // 250 - new Symbol(0b1110001101111111101001010111000, 31), // 251 - new Symbol(0b1110001101111111101001011011011, 31), // 252 - new Symbol(0b1110001101111111101001011010001, 31), // 253 - new Symbol(0b1110001101111111101001011010011, 31), // 254 - new Symbol(0b1110001101111111101001010111011, 31), // 255 - }; - - private static final Node ROOT = buildTree(0, 0); -} diff --git a/src/main/java/game/codec/BinarySearch.java b/src/main/java/game/codec/BinarySearch.java new file mode 100644 index 0000000..297df64 --- /dev/null +++ b/src/main/java/game/codec/BinarySearch.java @@ -0,0 +1,14 @@ +package org.lichess.compression.game.codec; + +import java.util.Arrays; + +public class BinarySearch { + public static int search(int[] a, int key) { + int i = Arrays.binarySearch(a, key); + boolean keyNotFound = i < 0; + if (keyNotFound) { + return -i - 2; + } + return i; + } +} diff --git a/src/main/java/game/codec/FrequencyDistribution.java b/src/main/java/game/codec/FrequencyDistribution.java new file mode 100644 index 0000000..9832b27 --- /dev/null +++ b/src/main/java/game/codec/FrequencyDistribution.java @@ -0,0 +1,35 @@ +package org.lichess.compression.game.codec; + +public class FrequencyDistribution +{ + private final int[] cdf; + + public FrequencyDistribution(int[] frequencies) { + this.cdf = computeCdf(frequencies); + } + + private static int[] computeCdf(int[] frequencies) { + int[] cdf = new int[frequencies.length + 1]; + cdf[0] = 0; + for (int i = 1; i < cdf.length; i++) { + cdf[i] = cdf[i - 1] + frequencies[i - 1]; + } + return cdf; + } + + public int getFrequencyAt(int i) { + return cdf[i + 1] - cdf[i]; + } + + public int getCdfAt(int i) { + return cdf[i]; + } + + public int getQuantileFunction(int frequency) { + return BinarySearch.search(cdf, frequency); + } + + public int getNumberOfBins() { + return cdf.length - 1; + } +} diff --git a/src/main/java/game/codec/Rans.java b/src/main/java/game/codec/Rans.java new file mode 100644 index 0000000..84ccba4 --- /dev/null +++ b/src/main/java/game/codec/Rans.java @@ -0,0 +1,39 @@ +package org.lichess.compression.game.codec; + +import org.lichess.compression.BitReader; +import org.lichess.compression.BitWriter; + +public class Rans { + private final int[] MOVE_INDEX_FREQUENCIES = {205, 120, 77, 58, 47, 35, 27, 23, 20, 18, 16, 15, 15, 12, 13, 11, 10, 10, 10, 8, 7, 7, 7, 6, 6, 5, 4, 4, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}; + private final int NUMBER_OF_QUANTIZATION_BITS = 10; + private final int NUMBER_OF_NORMALIZATION_BITS = 13; + private final RansConfiguration configuration = new RansConfiguration(NUMBER_OF_QUANTIZATION_BITS, NUMBER_OF_NORMALIZATION_BITS); + private final RansEncoder encoder; + private final RansDecoder decoder; + + public Rans() { + this.encoder = new RansEncoder(configuration, MOVE_INDEX_FREQUENCIES); + this.decoder = new RansDecoder(configuration, MOVE_INDEX_FREQUENCIES); + } + + public void write(int symbol, BitWriter writer) { + encoder.write(symbol, writer); + } + + public int read(BitReader reader) { + int symbol = decoder.read(reader); + return symbol; + } + + public void resetEncoder() { + encoder.reset(); + } + + public void initializeDecoder(int stateAfterEncoding) { + decoder.initialize(stateAfterEncoding); + } + + public int getState() { + return encoder.getState(); + } +} diff --git a/src/main/java/game/codec/RansConfiguration.java b/src/main/java/game/codec/RansConfiguration.java new file mode 100644 index 0000000..0111e9d --- /dev/null +++ b/src/main/java/game/codec/RansConfiguration.java @@ -0,0 +1,35 @@ +package org.lichess.compression.game.codec; + +public final class RansConfiguration { + private final int numberOfQuantizationBits; + private final int numberOfNormalizationBits; + private final int numberOfBitsToReadAndWrite; + private final int stateLowerBound; + + public RansConfiguration(int numberOfQuantizationBits, int numberOfNormalizationBits) { + this.numberOfQuantizationBits = numberOfQuantizationBits; + this.numberOfNormalizationBits = numberOfNormalizationBits; + numberOfBitsToReadAndWrite = 8; + stateLowerBound = 1 << (numberOfNormalizationBits + numberOfQuantizationBits); + } + + public int getNumberOfQuantizationBits() { + return numberOfQuantizationBits; + } + + public int getNumberOfNormalizationBits() { + return numberOfNormalizationBits; + } + + public int getNumberOfBitsToReadAndWrite() { + return numberOfBitsToReadAndWrite; + } + + public int getStateLowerBound() { + return stateLowerBound; + } + + public boolean stateUnderflowed(int state) { + return state < stateLowerBound; + } +} diff --git a/src/main/java/game/codec/RansDecoder.java b/src/main/java/game/codec/RansDecoder.java new file mode 100644 index 0000000..2f69bf1 --- /dev/null +++ b/src/main/java/game/codec/RansDecoder.java @@ -0,0 +1,54 @@ +package org.lichess.compression.game.codec; + +import org.lichess.compression.BitOps; +import org.lichess.compression.BitReader; + +class RansDecoder { + private final RansConfiguration configuration; + private final FrequencyDistribution symbolDistribution; + private int state; + private int symbol; + + public RansDecoder(RansConfiguration configuration, int[] symbolFrequencies) { + this.configuration = configuration; + this.symbolDistribution = new FrequencyDistribution(symbolFrequencies); + } + + public int read(BitReader reader) { + decode(); + normalizeUnderflowedState(reader); + return symbol; + } + + public void initialize(int state) { + this.state = state; + } + + private void decode() { + symbol = computeNextSymbol(); + state = computePreviousState(); + } + + private int computeNextSymbol() { + int i = BitOps.moduloPowerOfTwo(state, configuration.getNumberOfQuantizationBits()); + int nextSymbol = symbolDistribution.getQuantileFunction(i); + return nextSymbol; + } + + private int computePreviousState() { + int previousState = symbolDistribution.getFrequencyAt(symbol) * (state >> configuration.getNumberOfQuantizationBits()) + BitOps.moduloPowerOfTwo(state, configuration.getNumberOfQuantizationBits()) - symbolDistribution.getCdfAt(symbol); + return previousState; + } + + private void normalizeUnderflowedState(BitReader reader) { + while (configuration.stateUnderflowed(state)) { + int bits = readBits(reader); + state = (state << configuration.getNumberOfBitsToReadAndWrite()) + bits; + } + } + + private int readBits(BitReader reader) { + int bits = reader.readBits(configuration.getNumberOfBitsToReadAndWrite()); + return bits; + } +} diff --git a/src/main/java/game/codec/RansEncoder.java b/src/main/java/game/codec/RansEncoder.java new file mode 100644 index 0000000..32c0688 --- /dev/null +++ b/src/main/java/game/codec/RansEncoder.java @@ -0,0 +1,69 @@ +package org.lichess.compression.game.codec; + +import org.lichess.compression.BitOps; +import org.lichess.compression.BitWriter; + +class RansEncoder { + private final RansConfiguration configuration; + private final FrequencyDistribution symbolDistribution; + private final int[] stateUpperBoundBySymbol; + private int state; + private int symbol; + + public RansEncoder(RansConfiguration configuration, int[] symbolFrequencies) { + this.configuration = configuration; + this.symbolDistribution = new FrequencyDistribution(symbolFrequencies); + this.stateUpperBoundBySymbol = computeStateUpperBoundBySymbol(); + } + + public void write(int symbol, BitWriter writer) { + this.symbol = symbol; + normalizeOverflowedState(writer); + encode(); + } + + public int getState() { + return state; + } + + public void reset() { + state = configuration.getStateLowerBound(); + } + + private int[] computeStateUpperBoundBySymbol() { + int[] stateUpperBoundBySymbol = new int[symbolDistribution.getNumberOfBins()]; + for (int symbol = 0; symbol < symbolDistribution.getNumberOfBins(); symbol++) { + int stateUpperBound = (1 << configuration.getNumberOfNormalizationBits()) * symbolDistribution.getFrequencyAt(symbol) * + (1 << configuration.getNumberOfBitsToReadAndWrite()) - 1; + stateUpperBoundBySymbol[symbol] = stateUpperBound; + } + return stateUpperBoundBySymbol; + } + + private void normalizeOverflowedState(BitWriter writer) { + while (nextStateOverflows()) { + writeBits(writer); + state >>= configuration.getNumberOfBitsToReadAndWrite(); + } + } + + private boolean nextStateOverflows() { + return state > stateUpperBoundBySymbol[symbol]; + } + + private int computeNextState() { + int nextState = (Math.floorDiv(state, symbolDistribution.getFrequencyAt(symbol)) << configuration.getNumberOfQuantizationBits()) + + symbolDistribution.getCdfAt(symbol) + (state % symbolDistribution.getFrequencyAt(symbol)); + return nextState; + } + + private void writeBits(BitWriter writer) { + int bits = BitOps.moduloPowerOfTwo(state, configuration.getNumberOfBitsToReadAndWrite()); + writer.writeBits(bits, configuration.getNumberOfBitsToReadAndWrite()); + } + + private void encode() { + int nextState = computeNextState(); + state = nextState; + } +} diff --git a/src/test/scala/HuffmanPgnTest.scala b/src/test/scala/RansPgnTest.scala similarity index 76% rename from src/test/scala/HuffmanPgnTest.scala rename to src/test/scala/RansPgnTest.scala index 0033442..2d96b2a 100644 --- a/src/test/scala/HuffmanPgnTest.scala +++ b/src/test/scala/RansPgnTest.scala @@ -2,8 +2,7 @@ package org.lichess.compression.game import org.specs2.mutable.* -class HuffmanPgnTest extends Specification: - +class RansPgnTest extends Specification: def hexToBytes(str: String) = str.grouped(2).map(cc => Integer.parseInt(cc, 16).toByte).toArray @@ -14,122 +13,129 @@ class HuffmanPgnTest extends Specification: "compress and decompress" in: forall(fixtures) { pgn => val pgnMoves = pgn.split(" ") - val encoded = Encoder.encode(pgnMoves) - val decoded = Encoder.decode(encoded, pgnMoves.size) - pgnMoves must_== decoded.pgnMoves - } - - "stable format" in: - forall(v1 zip fixtures) { case (encoded, pgn) => - val pgnMoves = pgn.split(" ") - val decoded = Encoder.decode(base64ToBytes(encoded), pgnMoves.size) + val encoder = Encoder() + val encoded = encoder.encode(pgnMoves) + val decoded = encoder.decode(encoded, pgnMoves.size) pgnMoves must_== decoded.pgnMoves } - "least surprise" in: - val n = 22 - val decoded = Encoder.decode(Array.fill(n)(0.toByte), n) - decoded.pgnMoves.mkString(" ") must_== "e4 e5 Nf3 Nf6 Nxe5 Nxe4 Nxf7 Kxf7 d4 Nxf2 Kxf2 d5 Nc3 Nc6 Nxd5 Qxd5 Kg1 Nxd4 Qxd4 Qxd4+ Be3 Qxe3#" - "unmoved rooks" in: import scala.jdk.CollectionConverters.* val pgnMoves = "d4 h5 c4 Rh6 Nf3 Rh8".split(" ") - val encoded = Encoder.encode(pgnMoves) + val encoder = Encoder() + val encoded = encoder.encode(pgnMoves) - val d1 = Encoder.decode(encoded, 0) + val d1 = encoder.decode(encoded, 0) Bitboard.squareSet(d1.board.castlingRights).asScala must_== Set(0, 7, 56, 63) - val d2 = Encoder.decode(encoded, pgnMoves.size) + val d2 = encoder.decode(encoded, pgnMoves.size) Bitboard.squareSet(d2.board.castlingRights).asScala must_== Set(0, 7, 56) "half-move clock" in: - val pgnMoves = "e4 e5 Nf3 Nc6 Nc3 Nf6 Bb5 d6 O-O Be7 d4 exd4 Nxd4 Bd7 Bg5 O-O Nxc6 bxc6 Bd3 h6 Bh4 Ne8 Bxe7 Qxe7 Qf3 Nf6 Rfe1 Rfe8".split(" ") - val encoded = Encoder.encode(pgnMoves) - val halfMoveClocks = List(0, 0, 0, 1, 2, 3, 4, 5, 0, 1, 2, 0, 0, 0, 1, 2, 3, 0, 0, 1, 0, 1, 2, 0, 0, 1, 2, 3, 4) - (0 to pgnMoves.size).map(Encoder.decode(encoded, _).halfMoveClock) must_== halfMoveClocks + val pgnMoves = + "e4 e5 Nf3 Nc6 Nc3 Nf6 Bb5 d6 O-O Be7 d4 exd4 Nxd4 Bd7 Bg5 O-O Nxc6 bxc6 Bd3 h6 Bh4 Ne8 Bxe7 Qxe7 Qf3 Nf6 Rfe1 Rfe8" + .split(" ") + val encoder = Encoder() + val encoded = encoder.encode(pgnMoves) + val halfMoveClocks = + List(0, 0, 0, 1, 2, 3, 4, 5, 0, 1, 2, 0, 0, 0, 1, 2, 3, 0, 0, 1, 0, 1, 2, 0, 0, 1, 2, 3, 4) + (0 to pgnMoves.size).map(encoder.decode(encoded, _).halfMoveClock) must_== halfMoveClocks "last uci" in: - val pgnMoves = "e4 e5 Nf3 Nc6 Bc4 Nf6 d4 exd4 O-O Bc5 e5 d5 exf6 dxc4 Re1+ Be6 Ng5 Qxf6 Nxe6 Qxe6".split(" ") - val encoded = Encoder.encode(pgnMoves) + val pgnMoves = + "e4 e5 Nf3 Nc6 Bc4 Nf6 d4 exd4 O-O Bc5 e5 d5 exf6 dxc4 Re1+ Be6 Ng5 Qxf6 Nxe6 Qxe6".split(" ") + val encoder = Encoder() + val encoded = encoder.encode(pgnMoves) - val empty = Encoder.decode(encoded, 0) + val empty = encoder.decode(encoded, 0) Option(empty.lastUci) must_== None - val decoded = Encoder.decode(encoded, pgnMoves.size) + val decoded = encoder.decode(encoded, pgnMoves.size) Option(decoded.lastUci) must_== Some("f6e6") "position hash 1. e4 d5 2. e5 f5 3. Ke2 Kf7" in: val pgnMoves = "e4 d5 e5 f5 Ke2 Kf7".split(" ") - val encoded = Encoder.encode(pgnMoves) + val encoder = Encoder() + val encoded = encoder.encode(pgnMoves) // initial position - val d0 = Encoder.decode(encoded, 0) + val d0 = encoder.decode(encoded, 0) d0.positionHashes must_== hexToBytes("463b96") // 1. e4 - val d1 = Encoder.decode(encoded, 1) + val d1 = encoder.decode(encoded, 1) d1.positionHashes must_== hexToBytes("823c9b") // 1. e4 d5 - val d2 = Encoder.decode(encoded, 2) + val d2 = encoder.decode(encoded, 2) d2.positionHashes must_== hexToBytes("0756b9") // 1. e4 d5 2. e5 - val d3 = Encoder.decode(encoded, 3) + val d3 = encoder.decode(encoded, 3) d3.positionHashes must_== hexToBytes("662faf") // 1. e4 d5 2. e5 f5 (en passant matters) - val d4 = Encoder.decode(encoded, 4) + val d4 = encoder.decode(encoded, 4) d4.positionHashes must_== hexToBytes("22a48b") // 1. e4 d5 2. e5 f5 3. Ke2 - val d5 = Encoder.decode(encoded, 5) + val d5 = encoder.decode(encoded, 5) d5.positionHashes must_== hexToBytes("652a60" + "22a48b") // 1. e4 d5 2. e5 f5 3. Ke2 Kf7 - val d6 = Encoder.decode(encoded, 6) + val d6 = encoder.decode(encoded, 6) d6.positionHashes must_== hexToBytes("00fdd3" + "652a60" + "22a48b") "position hash 1. a4 b5 2. h4 b4 3. c4 bxc3 4. Ra3" in: val pgnMoves = "a4 b5 h4 b4 c4 bxc3 Ra3".split(" ") - val encoded = Encoder.encode(pgnMoves) + val encoder = Encoder() + val encoded = encoder.encode(pgnMoves) // 1. a4 b5 2. h4 b4 3. c4 - val d5 = Encoder.decode(encoded, 5) + val d5 = encoder.decode(encoded, 5) d5.positionHashes must_== hexToBytes("3c8123") // 1. a4 b5 2. h4 b4 3. c4 bxc3 4. Ra3 - val d7 = Encoder.decode(encoded, 7) + val d7 = encoder.decode(encoded, 7) d7.positionHashes must_== hexToBytes("5c3f9b" + "93d326") "position hash threefold" in: // https://lichess.org/V0m3eSGN - val pgnMoves = "Nf3 d5 d4 c5 dxc5 e6 c4 Bxc5 Nc3 Nf6 e3 O-O cxd5 Nxd5 Nxd5 Qxd5 Qxd5 exd5 Be2 Nc6 a3 Bf5 b4 Bb6 Bb2 Rfd8 Rd1 Rac8 O-O Ne7 Nd4 Bg6 Rc1 Rxc1 Rxc1 Nf5 Bf3 Kf8 Nb3 Nxe3 Bd4 Nc2 Bxb6 axb6 Bd1 Re8 Bxc2 Bxc2 Nd4 Bd3 f3 Bc4 Kf2 Re5 g4 g6 Rc3 Ke7 Re3 Kf6 h4 Rxe3 Kxe3 Ke5 f4+ Kd6 g5 Ke7 Nf3 Ke6 Nd4+ Ke7 Nf3 Ke6 Nd4+ Ke7".split(" ") - val encoded = Encoder.encode(pgnMoves) - val decoded = Encoder.decode(encoded, pgnMoves.size) + val pgnMoves = + "Nf3 d5 d4 c5 dxc5 e6 c4 Bxc5 Nc3 Nf6 e3 O-O cxd5 Nxd5 Nxd5 Qxd5 Qxd5 exd5 Be2 Nc6 a3 Bf5 b4 Bb6 Bb2 Rfd8 Rd1 Rac8 O-O Ne7 Nd4 Bg6 Rc1 Rxc1 Rxc1 Nf5 Bf3 Kf8 Nb3 Nxe3 Bd4 Nc2 Bxb6 axb6 Bd1 Re8 Bxc2 Bxc2 Nd4 Bd3 f3 Bc4 Kf2 Re5 g4 g6 Rc3 Ke7 Re3 Kf6 h4 Rxe3 Kxe3 Ke5 f4+ Kd6 g5 Ke7 Nf3 Ke6 Nd4+ Ke7 Nf3 Ke6 Nd4+ Ke7" + .split(" ") + val encoder = Encoder() + val encoded = encoder.encode(pgnMoves) + val decoded = encoder.decode(encoded, pgnMoves.size) val threefold = "966379" val ncheck = "65afff" val ke6 = "1bc865" val nf3 = "e804e3" val g5 = "ef8a0b" - decoded.positionHashes must_== hexToBytes(threefold + ncheck + ke6 + nf3 + threefold + ncheck + ke6 + nf3 + threefold + g5) + decoded.positionHashes must_== hexToBytes( + threefold + ncheck + ke6 + nf3 + threefold + ncheck + ke6 + nf3 + threefold + g5 + ) "position hash compat" in: // https://lichess.org/DoqH1EQP - val pgnMoves = "e4 c5 Nf3 d6 d4 cxd4 Nxd4 Nc6 Nc3 g6 Be3 Bg7 Bc4 Nf6 f3 O-O Qd2 Nd7 O-O-O a5 g4 Nce5 Be2 a4 a3 Nb6 h4 Nbc4 Bxc4 Nxc4 Qf2 Qb6 b3 Nxe3 Qxe3 e5 Nf5 Qxe3+ Nxe3 axb3 cxb3 Rxa3 Kb2 Ra6 h5 h6 hxg6 fxg6 Ned5 Rxf3 Ne7+ Kf7 Nxc8 Ke6 Nxd6 Rf2+ Kb1 Rxd6 Nd5 Rc6 Rc1 Rxc1+ Rxc1 Re2 Rc7 Rxe4 Nb6 Bf8 Rxb7 Rb4 Rb8 Rxb3+ Kc2 Rb5 Rxf8 Rxb6 Rg8 Kf6 Rf8+ Kg5 Rh8 Rd6 Re8 Kxg4 Rxe5 g5 Re3 Kf5".split(" ") - val encoded = Encoder.encode(pgnMoves) - val decoded = Encoder.decode(encoded, pgnMoves.size) + val pgnMoves = + "e4 c5 Nf3 d6 d4 cxd4 Nxd4 Nc6 Nc3 g6 Be3 Bg7 Bc4 Nf6 f3 O-O Qd2 Nd7 O-O-O a5 g4 Nce5 Be2 a4 a3 Nb6 h4 Nbc4 Bxc4 Nxc4 Qf2 Qb6 b3 Nxe3 Qxe3 e5 Nf5 Qxe3+ Nxe3 axb3 cxb3 Rxa3 Kb2 Ra6 h5 h6 hxg6 fxg6 Ned5 Rxf3 Ne7+ Kf7 Nxc8 Ke6 Nxd6 Rf2+ Kb1 Rxd6 Nd5 Rc6 Rc1 Rxc1+ Rxc1 Re2 Rc7 Rxe4 Nb6 Bf8 Rxb7 Rb4 Rb8 Rxb3+ Kc2 Rb5 Rxf8 Rxb6 Rg8 Kf6 Rf8+ Kg5 Rh8 Rd6 Re8 Kxg4 Rxe5 g5 Re3 Kf5" + .split(" ") + val encoder = Encoder() + val encoded = encoder.encode(pgnMoves) + val decoded = encoder.decode(encoded, pgnMoves.size) decoded.positionHashes must_== base64ToBytes("oB9I1h1e6YDy") "work with all black legal moves in YycayYfM" in: // Exclude compression as cause of issues with https://lichess.org/YycayYfM val prefix = "e4 c6 Nf3 d5 exd5 cxd5 d4 Nc6 c3 Nf6 Bf4 Bg4 Be2 e6 Nbd2 Bd6 Bxd6 Qxd6 O-O O-O Re1 a6 Ne5 Bxe2 Qxe2 Nd7 Nxd7 Qxd7 a4 Rab8 Nf3 b5 axb5 axb5 Ne5 Nxe5 Qxe5 b4 c4 dxc4 Rac1 Rbc8 Qa5 Qb7 Re2 c3 bxc3 bxc3 Rec2 Qe4 Qe5 Qxe5 dxe5 Rc5 f4 Rfc8 Kf2 f6 exf6 gxf6 Ke3" val legals = "Kh8 Kf8 Kg7 Kf7 Rf8 Re8 Rd8 Rb8 Ra8 R8c7 R8c6 R5c7 R5c6 Rh5 Rg5 Rf5 Re5+ Rd5 Rb5 Ra5 Rc4 h6 f5 e5 h5".split(" ") + val encoder = Encoder() forall(legals) { legal => val pgnMoves = (prefix + " " + legal).split(" ") - val encoded = Encoder.encode(pgnMoves) - val decoded = Encoder.decode(encoded, pgnMoves.size) + val encoded = encoder.encode(pgnMoves) + val decoded = encoder.decode(encoded, pgnMoves.size) pgnMoves must_== decoded.pgnMoves } @@ -137,8 +143,9 @@ class HuffmanPgnTest extends Specification: // Exclude compression as cause of https://github.com/ornicar/lila/issues/5594 val prefix = "c4 e5 g3 h5 Nc3 h4 Bg2 Nf6 d3 Bb4 Bd2 d6 Nf3 h3 Bf1 Nc6 e3 Bg4 Be2 d5 Nxd5 Nxd5 cxd5 Qxd5 Bxb4 Nxb4 Qa4+ c6 Qxb4 Bxf3 Bxf3 Qxf3 Rg1 O-O-O Qe4 Qf6 O-O-O Rd5 f4 Rhd8 Rgf1 Qe6 Kb1 f5 Qc4 e4 d4 Kb8 Rc1 Qe7 Rg1 Qd7 Qc2 Re8 Qe2 Ra5 g4 g6 gxf5 gxf5 Qh5 Rd8 Qh6 c5 Rg7 Qa4 a3 Qb3 Qf6 Rc8 Qd6+ Ka8" val pgnMoves = s"$prefix Rxc5 Raxc5".split(" ") - val encoded = Encoder.encode(pgnMoves) - val decoded = Encoder.decode(encoded, pgnMoves.size) + val encoder = Encoder() + val encoded = encoder.encode(pgnMoves) + val decoded = encoder.decode(encoded, pgnMoves.size) pgnMoves must_== decoded.pgnMoves "pass perft test" in: @@ -290,120 +297,3 @@ class HuffmanPgnTest extends Specification: "e4 b6 d4 Bb7 Bd3 e6 c4 Bb4+ Nc3 Bxc3+ bxc3 h6 Nf3 Nf6 Qe2 O-O O-O d6 h3 Nbd7 a4 e5 Re1 a5 Nh2 Nh7 d5 Nc5 Bc2 Bc8 f4 Qh4 Rf1 Nf6 fxe5 dxe5 Nf3 Qg3 Kh1", "e4 d5 exd5 Qxd5 Nc3 Qd8 Bc4 Nf6 d3 Bg4 f3 Bf5 Be3 e6 Nge2 c6 Ng3 Bg6 Qd2 Bd6 Nce4 Nxe4 Nxe4 Bc7 Bb3 Ba5 c3 Bxe4 fxe4 O-O O-O-O b5 h4 Bb6 d4 Nd7 g4 c5 g5 cxd4 Bxd4 Bxd4 Qxd4 Nb8 Qe3 Qc7" ) - - val v1 = List( - "7qasJezzPJK15lj9CbbYheEA63S9DE37qYM/HcONsibhbJM/2xJqSwr/nVAX79Rn3x/vsAA=", - "KjTb/Zzt6FTIF/lVyHjtbeOzYeV9uhNzDfuV/699pPx/1XWiwVs31MA=", - "Mp0orWLvti0lxmh6kBmGf5IqTYEAdXvgx/3Jnivwhju9A6ImWcvOcc9n1FmEwA==", - "MhU6x0SImzC1OgAhmyHHSZLcNtUGucvp9TLlpoA=", - "PDdknk9du7oA11Y1tRdCpolRK+yysDyJ9z1Q", - "s3sbOnTq9vX15Npv7x4fJ97xFroPTbOLG+n9Q3639s5WH/7BQA==", - "Hw15mn6XrBtZGTjK0A==", - "Ugpnwa0n6QjIy8kUOHvF4vWAfNWGLlSpu4AGDXj+CUAfbz9YqN7Jq1sqLU2rqQmcfA==", - "19Sn95fufWzCaFjfjA==", - "Hw3FxvG9IvO817llUSMRAfV5xvfz/7Ez4kfrQWMgfg1NZz1n1D9gW0AqLQ==", - "LM7AHuV0bWvXpOGmfCsWvct0pUMzO1AV0XTqwBX68pFIq0XP0rvp2g==", - "zkmxSlt9/tqJ03f9B/Lbv8WnSSTNFLf11Wuz8dxe1//QCb9wv4i1/3JwIZWzXWaTwgMK9vGUr/ZVqt91WbzrlUA=", - "CGtPQ+eYf80+X/0C63rNmndIuJCBKzcx5b4u4nm3ha8Ul3updJ5mjhqLG4eUexdgdHNLHYWqpOg=", - "k8XmH/XqWTUtzb0FfAEODSU9gT2plavcE2l1IS+2kEcMmOdTr4P22vtsvWFnPMz9Zf3Ly1mBM0VfUV4TK/ZCZzu3AA==", - "k+fO9m7ugAzYRzplfkRc++fc6BmXoIFCkQ9zO7c6", - "BUL1i68rpzjsWPUJ2Pdk92pZJ6nCGB72yZwpuLCz0C7bn9wXPruaw0gbdav0rZ1ffFWg6A==", - "LOthhPsUt3Z477iYLiHYfBdAEvKxfMyDjTPrr0K7AhTjndfDLSzmoMa8tqkb30g=", - "AnSG1kzU5Yq/3W1r8pc++W3e8G/rdkvOFn93x/v+rkwd0PU11u9ULwgHmeCdv3BHXteI712hWYNcoA==", - "PDfxJxP78de/6HhPgcrkATDkByNP67npuEF2FeGmmgiA", - "mArWBeET583tvuNqHEQkAbLGO4J/OeGnlzebr+lIgaBv7cA=", - "ty73VQUu24W1KDl7tqeOC+TKoueDkJ+WLfYnAkA=", - "MF7OrJjpv1PueNH5uAZ4958JYA==", - "KnCPQhSuzsnzxJbZrPH/mK3pkFuPLbPWc98A", - "ysrQlz6bg1S5bgjzeeXJQA==", - "LFb0KkXcNZ3f2PBtQhABf4A=", - "Bg1PYNwxwIj/tXo=", - "KmPn/CTPMXaQz7Uae8eQ4uevEMA=", - "A618/WLNgHe98/vso2VXG9HicZOPF3sqzyMg", - "kvv8yK3ZWdJ+Iw/Bbb+r1P9r7H0w3LUN/gza+Y7c9wRAHZA=", - "k/0760KrEmRDhIzLFWHxrVuiH9Kr3+50x038r7eBXSSkP99u68l7WrvrGTlFQf6QKoMIXXUdWNE7AA==", - "Andh9gQNgU0xiZBfIqcgy86PbN9q/+SAuig=", - "K+8IExxnDUMbfTwtmne8O0su", - "CVSUMTl6Sc8/JxJ+pfX1Xa7lah9meFtCcE9FwUA=", - "kDe8/E+orRd88dwxtGhpTW9ARh5rEb/mYc4gmoA=", - "LNob+KTfOVRolcRnunH82UEhJgrmfdm1znTvc+Bj8NIM1PQC1GHviouxfWOuL2gdrjoXyqS4", - "VeEdjtGlsxcoa8Utvm92IsOEa866aNNLKkLxsA==", - "qCefpoSA3lN73H4DrIi8nznXtZhWpe47p0iSzXItfUA=", - "Hw3Gq73PyY3AB+unnhf0ucss", - "kN+zTf+9SWT92zEowhgma9xPUYw=", - "PBf5IfZfEMt13fvnNv5H38D16TuLwbXtaevbTQ==", - "WGHCpp5fvgWYfrjTWMjJL8VKerZcArLrTuM12HsGwZ756ZZKwtz6OLf+Qdct0L7EveqUsA==", - "KPwAiyY70yvnGX5k+YV/2S7n/VpPi9aA", - "yCVgblV+5/46gWBj3m0/sxpu/kw2mIS118A=", - "LI2BEkK/7hW4MfwSJF6A", - "PDfxISLpH9w5P6PcHee9xw1ecXry8azxeIh/+MhwB+5iKMg=", - "AKXNwe8rYUQnjecT", - "Mijq//lj3oo9eBUDTsjcpBiZAnhnUgB7tu4A", - "KPzZxNGQT1m2uxXaT8vP57BM2Av+FPkDJoxyeA/00A==", - "KKlErsNbydzbX/cHneNuv2SepoijLP4Gl79+gA==", - "ijQARzN4ef+W+bKHatW6bfjv17zlvQoB7dL/+vg0jvcLgA==", - "Hw3/RL0zFev0EarLk7tLPO80aL0=", - "Ar8BUYsLNYtZ8NrZA8/24fP35eGc1r1t+1nluSNy5A==", - "Cc1+r2sEPuqnP64hvbaJANL98J6A", - "KMeJCkLZxEJA0vfbz905zbTT635tQhDf3+rg", - "PP2up0ynHt0htYcq6g==", - "t3q/4EMWf6sv5lPOqqKb4I+GgA==", - "wW+5xenkBqqSUBxw51HLWE/rnlVRT4+Bu+oW3I33qN1Q", - "FjvmdkNJVB5+ZBO/Nf5b4Uxzexnq5edsx8Q=", - "yrTf297tPr69zzvTWISECwPNwfgD4A==", - "k+bvxZIbt0+pHi14k996oAV+85eGH/GdzCbIcfcaV/D8k/fEctOPbA==", - "DGoG0zKWZ3Ri5UseardCAA==", - "yNGyvxLz6rDSYlDJxgWfgA==", - "AlLU3aquoXcZKTWrbo8hJWj7XP/hAA6+w958dn9g7uphyLQ=", - "AnSvMbff/5P/TQW75A==", - "PBf5IW1ElYqorh374xIv7ux2P9VkmXDA7Va/h1bMsZ3SKFeGgA==", - "AnTzvDYHVmTFtxHydu/Lk5/R09nkJDQ=", - "AncICtSWO5R+idYO", - "3LKlNwT2x3SA/CYnfbjy7NIBG+uA", - "BfkxwR9+wWh31do2pC7h8q0UHVH7PMny+KA=", - "zqzheaUYv1au7CqrC6wsSW2s5sduyr8H+CVj0A==", - "PDfxH171aqzx7y0S4PCb6pBTbe/ANIaTdQNE", - "And38kC6LaCaV9dzuwA=", - "PKX6DR2s9l8qbzZz61Zn797XGvrxshsFz90X9xBb37yb9ep6zlfxLfcaoNAnkA==", - "Hw3FMUvXQy/dXu2s53zuhEtcTXf/L/Sbj4VvmbL2fO3CQcHee6F+b+dpVve4", - "g1JW3bIVjsu9rz7XYKZw", - "PGobF3d9dE6jaGJP+yCEu4/aR8akFFy9/A==", - "Kn0wuuDGvJvHmuyxQQ/1x+HAO8UkcZjngA==", - "kDS9c7zX14f2YOa7LSctZcc=", - "Ux1gmtqcnDGFbj2m5uEdY0B7B+72GGdzdIhJn6oBaA==", - "LIpO1gr+zHwWm1oZ9k3NHFbX7YYLoA==", - "Hw15mlUuGWs52FMgbi2Yvmg5Ozu8QAm27/++feV2Nt6X9AA=", - "Hw3FHmesQ7pDTLDtujGKhNkGkXmsdhDMyRnh5cA=", - "PGsBef76/gAFdX35X7zII3VJOs4uYA==", - "Hw15muXlxnv/mwEfKIGTLPmq1XxK/8zA", - "PK38njqKyEf191+aZvDR72/lnF/8dzoBkn6w", - "Aqu5uYBJFv8cH6q7TUjQb/8RscA=", - "tbvq0pKw8Tzx/STLSbF/O4rP0loMSBOOaA==", - "k+fZcHe+tUTJaqve4V5x/9ap3TPtZ8+tgdUuwA==", - "yLDopokZvj0PPspQvBnxtK6xgB+8AA==", - "3cIv1EaJIPehs793eyvZ5VQ8Zk3LI/z/bQtjLOzg7X38sag=", - "K/7LAczy54Y4TVa+ROYn6Te8FIUzvFVZQ80G84nF3TV/0rGqxuSgf2qwdHA=", - "g15tDk7MXCB+RJJpSKXTU5mQT0rd7YA=", - "mE7/kser33E8CE1uut9fzQ/YuIeYrV5rMEMzQFv4bztZNgA=", - "ngW6+L9e7v9CaENOjMwrsImE/A==", - "qC/W53U/M5PmXjfyr08sfeYA", - "k+K/u9Xh5jwkDGkzLdbe+5+1SbPF3/4vZld0Ijj8zZE4F/53rml7GgA=", - "CXyx/+8NAXes0vnfG37yxKzsZOOcGP8F4Pg=", - "mSl6sd++fLz19Kjl78v5vnFugwA=", - "Aqu6BVSbVSeQl77Iw/9A", - "AncJDmtxgp+9rG1IYJki3A==", - "Gb5fqN92K1wXqTIAHat075BvSXYarXeHUwYg", - "HwRZVB+8EdSBwxO/gzJtnefuJvzikOAkk+JY6wGuu/Zu/od+w8/219zQAA==", - "g03g3QQ/JqZVa0zBYvmN6pCPo4piXKbAZFJ5z+5+qiA=", - "BTers9z+p1zbk+G99SNLW0A=", - "k+Te7N29ZINd882eyt4Z2flvRzHg3AbA", - "yNKyHpcbseyQ9JfUnydJ2nmfh1iScNlQLWudN++/rRBXFoCqZnNZmIlSFheA", - "ylZ7XEFp1gDUqZO4fsQ0DPFR8xWqmf/i6mveCeIu9/bz0A==", - "7P9qO6ilfsxp1Y9nwnl4ANJG6VfJ8tD1/nKGAKG64EL7t9EIIXXqlKA=", - "Albpi7O7qd0nDJ2QVpxky7SfzOwR+vIuDGbH2F3+qA==", - "k+GyV5v1e84dMcCE22afckIDmA==", - "AjRJW+QgSMR9fZwwbXHpXXA=", - "lJ/o9i7j3b+eVTb/jdx/YA==", - "MBXVjuQzk9ZbSKw/IzSfmL9/8WX+4bn+", - "Hw15XVdxZr3eLz6xem8XLy8NLfv1/u0n/U0ifPQ=" - )