|
4 | 4 | # See https://wiibrew.org/wiki/LZ77 for details about the LZ77 compression format. |
5 | 5 |
|
6 | 6 | import io |
| 7 | +from dataclasses import dataclass as _dataclass |
| 8 | + |
| 9 | + |
# Limits on the back-references written by the compressor. A reference is stored as a 16-bit value holding
# (length - _LZ_MIN_LENGTH) in 4 bits and (distance - _LZ_MIN_DISTANCE) in 12 bits, which is where the maximums
# below come from.
_LZ_MIN_DISTANCE = 0x01  # Minimum distance for each reference (the byte directly before the current one).
_LZ_MAX_DISTANCE = 0x1000  # Maximum distance for each reference (4096 bytes back).
_LZ_MIN_LENGTH = 0x03  # Minimum length for each reference; shorter runs are written as literal bytes.
_LZ_MAX_LENGTH = 0x12  # Maximum length for each reference (18 bytes).
| 14 | + |
| 15 | + |
| 16 | +@_dataclass |
| 17 | +class _LZNode: |
| 18 | + dist: int = 0 |
| 19 | + len: int = 0 |
| 20 | + weight: int = 0 |
| 21 | + |
| 22 | + |
def _compress_compare_bytes(byte1: bytes, offset1: int, byte2: bytes, offset2: int, abs_len_max: int) -> int:
    """
    Counts how many consecutive bytes are equal between two buffers, starting at the given offsets.

    Parameters
    ----------
    byte1: bytes
        The first buffer.
    offset1: int
        The offset in the first buffer to start comparing from.
    byte2: bytes
        The second buffer.
    offset2: int
        The offset in the second buffer to start comparing from.
    abs_len_max: int
        The maximum number of bytes to compare.

    Returns
    -------
    int
        The number of leading bytes that matched, from 0 up to abs_len_max.
    """
    matched = 0
    for step in range(abs_len_max):
        # Stop counting at the first pair of bytes that differ.
        if byte1[offset1 + step] != byte2[offset2 + step]:
            break
        matched = step + 1
    return matched
| 31 | + |
| 32 | + |
def _compress_search_matches(buffer: bytes, pos: int) -> tuple[int, int]:
    """
    Searches the data before pos for the longest run of bytes that matches the data starting at pos.

    Parameters
    ----------
    buffer: bytes
        The uncompressed data being searched.
    pos: int
        The position in the buffer to find a match for.

    Returns
    -------
    tuple[int, int]
        The length of the longest match and how many bytes back from pos it starts. Both are 0 if nothing matched.
    """
    # Default to only looking back 4096 bytes, unless we've moved fewer than 4096 bytes, in which case we should
    # only look as far back as we've gone.
    max_dist = min(_LZ_MAX_DISTANCE, pos)
    # Default to only matching up to 18 bytes, unless fewer than 18 bytes remain, in which case we can only match
    # up to that many bytes.
    max_len = min(_LZ_MAX_LENGTH, len(buffer) - pos)
    # Log the longest match we found and its offset.
    biggest_match, biggest_match_pos = 0, 0
    # Search outwards from the nearest distance. Only a strictly longer match replaces the current one, so the
    # nearest distance wins when several matches have the same length.
    for dist in range(_LZ_MIN_DISTANCE, max_dist + 1):
        num_matched = _compress_compare_bytes(buffer, pos - dist, buffer, pos, max_len)
        if num_matched > biggest_match:
            biggest_match, biggest_match_pos = num_matched, dist
            # Nothing can beat a match of the maximum length, so stop looking.
            if biggest_match == max_len:
                break
    return biggest_match, biggest_match_pos
| 53 | + |
| 54 | + |
def _compress_node_is_ref(node: _LZNode) -> bool:
    """
    Reports whether a node should be written as a reference rather than as a literal byte.

    Parameters
    ----------
    node: _LZNode
        The node to check.

    Returns
    -------
    bool
        True if the node's length is at least the minimum reference length, otherwise False.
    """
    too_short_for_ref = node.len < _LZ_MIN_LENGTH
    return not too_short_for_ref
| 57 | + |
| 58 | + |
def _compress_get_node_cost(length: int) -> int:
    """
    Gets the cost in bits of writing a node that covers the given number of bytes.

    Parameters
    ----------
    length: int
        The number of input bytes the node covers.

    Returns
    -------
    int
        One flag bit plus 16 bits if the length is long enough to be a reference, or plus 8 bits otherwise.
    """
    payload_bytes = 2 if length >= _LZ_MIN_LENGTH else 1
    return payload_bytes * 8 + 1
| 65 | + |
| 66 | + |
def compress_lz77(data: bytes) -> bytes:
    """
    Compresses data using the Wii's LZ77 compression algorithm and returns the compressed result.

    The data is first scanned from the end towards the start, choosing for every position the node (a literal byte or
    a reference to earlier data) that gives the smallest encoded size for the rest of the data. The chosen nodes are
    then written front to back in chunks of up to eight nodes, each chunk preceded by a flag byte.

    Parameters
    ----------
    data: bytes
        The data to compress.

    Returns
    -------
    bytes
        The LZ77-compressed data.

    Raises
    ------
    OverflowError
        If the data is too large for its length to fit in the 3-byte size field of the header.
    """
    data_len = len(data)
    # The header can only hold a 3-byte length. Check this before running the match search instead of failing when
    # the header is written afterwards.
    if data_len > 0xFFFFFF:
        raise OverflowError("data is too large to compress: its length does not fit in the 3-byte LZ77 size field")

    nodes = [_LZNode() for _ in range(data_len)]
    # Iterate over the uncompressed data, starting from the end, so that the weight of every later position is
    # already known when a position is evaluated.
    for pos in range(data_len - 1, -1, -1):
        node = nodes[pos]
        # Initialize as 1 for each, since that's all we could use if we weren't compressing.
        length, dist = 1, 1
        # Only search when enough bytes remain for a reference to be possible.
        if min(_LZ_MAX_LENGTH, data_len - pos) >= _LZ_MIN_LENGTH:
            length, dist = _compress_search_matches(data, pos)
        # Treat as direct bytes if it's too short to copy.
        if length < _LZ_MIN_LENGTH:
            length = 1
        node.dist = dist
        # If the node goes to the end of the file, the weight is the cost of the node.
        if pos + length == data_len:
            node.len = length
            node.weight = _compress_get_node_cost(length)
            continue
        # Otherwise, try every usable length from the longest match down to the shortest reference, then a single
        # literal byte, and keep the cheapest. Only a strictly lower weight replaces the current best, so the longest
        # candidate wins ties.
        best_len, best_weight = 1, None
        for candidate in (*range(length, _LZ_MIN_LENGTH - 1, -1), 1):
            weight = _compress_get_node_cost(candidate) + nodes[pos + candidate].weight
            if best_weight is None or weight < best_weight:
                best_len, best_weight = candidate, weight
        node.len = best_len
        node.weight = best_weight

    # Write the header data.
    out = bytearray(b'LZ77\x10')  # The LZ type on the Wii is *always* 0x10.
    out += data_len.to_bytes(3, 'little')

    src_pos = 0
    while src_pos < data_len:
        # Reserve a byte for the chunk head; it's filled in once the chunk's nodes have been written.
        head_pos = len(out)
        out.append(0)
        head = 0
        for slot in range(8):
            if src_pos >= data_len:
                break
            current_node = nodes[src_pos]
            # This is a reference node.
            if _compress_node_is_ref(current_node):
                encoded = ((((current_node.len - _LZ_MIN_LENGTH) & 0xF) << 12)
                           | ((current_node.dist - _LZ_MIN_DISTANCE) & 0xFFF))
                # The byte order is given explicitly, since to_bytes() only has a default for it on Python 3.11+.
                out += encoded.to_bytes(2, 'big')
                # Flag this node as a reference; the first node in a chunk uses the highest bit.
                head |= 0x80 >> slot
            # This is a direct copy node.
            else:
                out += data[src_pos:src_pos + 1]
            src_pos += current_node.len
        out[head_pos] = head

    return bytes(out)
7 | 156 |
|
8 | 157 |
|
9 | 158 | def decompress_lz77(lz77_data: bytes) -> bytes: |
|
0 commit comments