diff --git a/src/noiseprotocol/crypto/aes_gcm.lua b/src/noiseprotocol/crypto/aes_gcm.lua index 51ef88c..58cb6ad 100644 --- a/src/noiseprotocol/crypto/aes_gcm.lua +++ b/src/noiseprotocol/crypto/aes_gcm.lua @@ -10,12 +10,12 @@ local utils = require("noiseprotocol.utils") local bytes = utils.bytes local benchmark_op = utils.benchmark.benchmark_op --- Local references for performance (avoid module table lookups in hot loops) -local bit32_band = bit32.band -local bit32_bor = bit32.bor -local bit32_bxor = bit32.bxor -local bit32_lshift = bit32.lshift -local bit32_rshift = bit32.rshift +-- Local references for performance +local bit32_raw_band = bit32.raw_band +local bit32_raw_bor = bit32.raw_bor +local bit32_raw_bxor = bit32.raw_bxor +local bit32_raw_lshift = bit32.raw_lshift +local bit32_raw_rshift = bit32.raw_rshift local string_byte = string.byte local string_char = string.char local string_rep = string.rep @@ -302,50 +302,29 @@ local RCON = { 0x36, } ---- @alias AESGCMWord [integer, integer, integer, integer] ---- @alias AESGCMBlock [integer, integer, integer, integer, integer, integer, integer, integer, integer, integer, integer, integer, integer, integer, integer, integer] ---- @alias AESGCMState [AESGCMWord, AESGCMWord, AESGCMWord, AESGCMWord] +--- @alias AESWord [integer, integer, integer, integer] +--- @alias AESBlock [integer, integer, integer, integer, integer, integer, integer, integer, integer, integer, integer, integer, integer, integer, integer, integer] +--- @alias AESState [AESWord, AESWord, AESWord, AESWord] ---- Initialize a 16-element GCM block with zeros ---- @return AESGCMBlock block Initialized block -local function create_gcm_block() - local arr = {} - for i = 1, 16 do - arr[i] = 0 - end - --- @cast arr AESGCMBlock - return arr +--- Initialize a 4-element AES word with zeros +--- @return AESWord word Initialized word +local function create_aes_word() + --- @type AESWord + return { 0, 0, 0, 0 } end --- Initialize a 4x4 AES state array with zeros 
---- @return AESGCMState state Initialized state +--- @return AESState state Initialized state local function create_aes_state() - local state = {} - for i = 1, 4 do - state[i] = {} - for j = 1, 4 do - state[i][j] = 0 - end - end - --- @cast state AESGCMState - return state -end - ---- Initialize a 4-element AES word with zeros ---- @return AESGCMWord word Initialized word -local function create_aes_word() - local arr = {} - for i = 1, 4 do - arr[i] = 0 - end - --- @cast arr AESGCMWord - return arr + --- @type AESState + return { + create_aes_word(), + create_aes_word(), + create_aes_word(), + create_aes_word(), + } end --- Pre-allocated arrays for gcm_multiply() to avoid repeated allocation -local gcm_z = create_gcm_block() -local gcm_v = create_gcm_block() - -- Pre-allocated state array for aes_encrypt_block() local aes_state = create_aes_state() @@ -354,28 +333,28 @@ local mix_a = create_aes_word() local mix_b = create_aes_word() --- XOR two 4-byte words ---- @param a AESGCMWord 4-byte array ---- @param b AESGCMWord 4-byte array ---- @return table Word 4-byte array +--- @param a AESWord 4-byte array +--- @param b AESWord 4-byte array +--- @return AESWord result 4-byte array local function xor_words(a, b) return { - bit32_bxor(a[1], b[1]), - bit32_bxor(a[2], b[2]), - bit32_bxor(a[3], b[3]), - bit32_bxor(a[4], b[4]), + bit32_raw_bxor(a[1], b[1]), + bit32_raw_bxor(a[2], b[2]), + bit32_raw_bxor(a[3], b[3]), + bit32_raw_bxor(a[4], b[4]), } end --- Rotate word (circular left shift by 1 byte) ---- @param word AESGCMWord 4-byte array ---- @return AESGCMWord result Rotated 4-byte array +--- @param word AESWord 4-byte array +--- @return AESWord result Rotated 4-byte array local function rot_word(word) return { word[2], word[3], word[4], word[1] } end --- Apply S-box substitution to a word ---- @param word AESGCMWord 4-byte array ---- @return AESGCMWord result Substituted 4-byte array +--- @param word AESWord 4-byte array +--- @return AESWord result Substituted 4-byte 
array local function sub_word(word) local s_1 = assert(SBOX[word[1] + 1], "Invalid SBOX index " .. (word[1] + 1)) local s_2 = assert(SBOX[word[2] + 1], "Invalid SBOX index " .. (word[2] + 1)) @@ -407,7 +386,7 @@ local function key_expansion(key) end -- Convert key to words - --- @type AESGCMState + --- @type AESState local w = {} for i = 1, nk do w[i] = { @@ -435,7 +414,7 @@ local function key_expansion(key) end --- MixColumns transformation ---- @param state AESGCMState 4x4 state matrix +--- @param state AESState 4x4 state matrix local function mix_columns(state) -- Reuse pre-allocated arrays local a = mix_a @@ -443,19 +422,20 @@ local function mix_columns(state) for c = 1, 4 do for i = 1, 4 do a[i] = state[i][c] - b[i] = bit32_band(state[i][c], 0x80) ~= 0 and bit32_bxor(bit32_band(bit32_lshift(state[i][c], 1), 0xFF), 0x1B) - or bit32_band(bit32_lshift(state[i][c], 1), 0xFF) + b[i] = bit32_raw_band(state[i][c], 0x80) ~= 0 + and bit32_raw_bxor(bit32_raw_band(bit32_raw_lshift(state[i][c], 1), 0xFF), 0x1B) + or bit32_raw_band(bit32_raw_lshift(state[i][c], 1), 0xFF) end - state[1][c] = bit32_bxor(bit32_bxor(bit32_bxor(b[1], a[2]), bit32_bxor(b[2], a[3])), a[4]) - state[2][c] = bit32_bxor(bit32_bxor(bit32_bxor(a[1], b[2]), bit32_bxor(a[3], b[3])), a[4]) - state[3][c] = bit32_bxor(bit32_bxor(bit32_bxor(a[1], a[2]), bit32_bxor(b[3], a[4])), b[4]) - state[4][c] = bit32_bxor(bit32_bxor(bit32_bxor(a[1], b[1]), bit32_bxor(a[2], a[3])), b[4]) + state[1][c] = bit32_raw_bxor(bit32_raw_bxor(bit32_raw_bxor(b[1], a[2]), bit32_raw_bxor(b[2], a[3])), a[4]) + state[2][c] = bit32_raw_bxor(bit32_raw_bxor(bit32_raw_bxor(a[1], b[2]), bit32_raw_bxor(a[3], b[3])), a[4]) + state[3][c] = bit32_raw_bxor(bit32_raw_bxor(bit32_raw_bxor(a[1], a[2]), bit32_raw_bxor(b[3], a[4])), b[4]) + state[4][c] = bit32_raw_bxor(bit32_raw_bxor(bit32_raw_bxor(a[1], b[1]), bit32_raw_bxor(a[2], a[3])), b[4]) end end --- SubBytes transformation ---- @param state AESGCMState 4x4 state matrix +--- @param state 
AESState 4x4 state matrix local function sub_bytes(state) for i = 1, 4 do for j = 1, 4 do @@ -466,7 +446,7 @@ local function sub_bytes(state) end --- ShiftRows transformation ---- @param state AESGCMState 4x4 state matrix +--- @param state AESState 4x4 state matrix local function shift_rows(state) -- Row 1: no shift -- Row 2: shift left by 1 @@ -493,14 +473,14 @@ local function shift_rows(state) end --- AddRoundKey transformation ---- @param state AESGCMState 4x4 state matrix +--- @param state AESState 4x4 state matrix --- @param round_key table Round key words --- @param round integer Round number local function add_round_key(state, round_key, round) for c = 1, 4 do local key_word = round_key[round * 4 + c] for r = 1, 4 do - state[r][c] = bit32_bxor(state[r][c], key_word[r]) + state[r][c] = bit32_raw_bxor(state[r][c], key_word[r]) end end end @@ -552,6 +532,21 @@ end -- GCM MODE IMPLEMENTATION -- ============================================================================ +--- Initialize a 16-element GCM block with zeros +--- @return AESBlock block Initialized block +local function create_gcm_block() + local arr = {} + for i = 1, 16 do + arr[i] = 0 + end + --- @cast arr AESBlock + return arr +end + +-- Pre-allocated arrays for gcm_multiply() to avoid repeated allocation +local gcm_z = create_gcm_block() +local gcm_v = create_gcm_block() + --- GCM field multiplication --- @param x string 16-byte block --- @param y string 16-byte block @@ -570,27 +565,27 @@ local function gcm_multiply(x, y) for i = 1, 16 do local byte = string_byte(x, i) for bit = 7, 0, -1 do - if bit32_band(byte, bit32_lshift(1, bit)) ~= 0 then + if bit32_raw_band(byte, bit32_raw_lshift(1, bit)) ~= 0 then -- z = z XOR v for j = 1, 16 do - z[j] = bit32_bxor(z[j], v[j]) + z[j] = bit32_raw_bxor(z[j], v[j]) end end -- Check if LSB of v is 1 (bit 0 of last byte) - local lsb = bit32_band(v[16], 1) + local lsb = bit32_raw_band(v[16], 1) -- v = v >> 1 (right shift entire 128-bit value by 1) local carry = 0 
for j = 1, 16 do - local new_carry = bit32_band(v[j], 1) - v[j] = bit32_bor(bit32_rshift(v[j], 1), bit32_lshift(carry, 7)) + local new_carry = bit32_raw_band(v[j], 1) + v[j] = bit32_raw_bor(bit32_raw_rshift(v[j], 1), bit32_raw_lshift(carry, 7)) carry = new_carry end -- If LSB was 1, XOR with R = 0xE1000000000000000000000000000000 if lsb ~= 0 then - v[1] = bit32_bxor(v[1], 0xE1) + v[1] = bit32_raw_bxor(v[1], 0xE1) end end end @@ -617,7 +612,7 @@ local function ghash(h, data) -- y = (y XOR block) * h local y_xor = "" for j = 1, 16 do - y_xor = y_xor .. string_char(bit32_bxor(string_byte(y, j), string_byte(block, j))) + y_xor = y_xor .. string_char(bit32_raw_bxor(string_byte(y, j), string_byte(block, j))) end y = gcm_multiply(y_xor, h) @@ -642,7 +637,7 @@ local function inc_counter(counter) -- Convert back to bytes (big-endian) for i = 3, 0, -1 do - result = result .. string_char(bit32_band(bit32_rshift(val, i * 8), 0xFF)) + result = result .. string_char(bit32_raw_band(bit32_raw_rshift(val, i * 8), 0xFF)) end return result @@ -762,7 +757,7 @@ function aes_gcm.encrypt(key, nonce, plaintext, aad) local keystream = generate_keystream(key, nonce, #plaintext) local ciphertext = "" for i = 1, #plaintext do - ciphertext = ciphertext .. string_char(bit32_bxor(string_byte(plaintext, i), string_byte(keystream, i))) + ciphertext = ciphertext .. string_char(bit32_raw_bxor(string_byte(plaintext, i), string_byte(keystream, i))) end -- Calculate authentication tag @@ -773,7 +768,7 @@ function aes_gcm.encrypt(key, nonce, plaintext, aad) local encrypted_j0 = aes_encrypt_block(j0, expanded_key, nr) local tag = "" for i = 1, 16 do - tag = tag .. string_char(bit32_bxor(string_byte(s, i), string_byte(encrypted_j0, i))) + tag = tag .. string_char(bit32_raw_bxor(string_byte(s, i), string_byte(encrypted_j0, i))) end return ciphertext .. 
tag @@ -841,7 +836,7 @@ function aes_gcm.decrypt(key, nonce, ciphertext_and_tag, aad) local encrypted_j0 = aes_encrypt_block(j0, expanded_key, nr) local expected_tag = "" for i = 1, 16 do - expected_tag = expected_tag .. string_char(bit32_bxor(string_byte(s, i), string_byte(encrypted_j0, i))) + expected_tag = expected_tag .. string_char(bit32_raw_bxor(string_byte(s, i), string_byte(encrypted_j0, i))) end -- Verify tag (constant-time comparison) @@ -853,7 +848,7 @@ function aes_gcm.decrypt(key, nonce, ciphertext_and_tag, aad) local keystream = generate_keystream(key, nonce, #ciphertext) local plaintext = "" for i = 1, #ciphertext do - plaintext = plaintext .. string_char(bit32_bxor(string_byte(ciphertext, i), string_byte(keystream, i))) + plaintext = plaintext .. string_char(bit32_raw_bxor(string_byte(ciphertext, i), string_byte(keystream, i))) end return plaintext diff --git a/src/noiseprotocol/crypto/blake2.lua b/src/noiseprotocol/crypto/blake2.lua index a4fd561..76eed60 100644 --- a/src/noiseprotocol/crypto/blake2.lua +++ b/src/noiseprotocol/crypto/blake2.lua @@ -12,13 +12,13 @@ local utils = require("noiseprotocol.utils") local bytes = utils.bytes local benchmark_op = utils.benchmark.benchmark_op --- Local references for performance (avoid module table lookups in hot loops) -local bit32_add = bit32.add -local bit32_bxor = bit32.bxor -local bit32_ror = bit32.ror -local bit64_add = bit64.add -local bit64_xor = bit64.xor -local bit64_ror = bit64.ror +-- Local references for performance +local bit32_raw_add = bit32.raw_add +local bit32_raw_bxor = bit32.raw_bxor +local bit32_raw_ror = bit32.raw_ror +local bit64_raw_add = bit64.raw_add +local bit64_raw_bxor = bit64.raw_bxor +local bit64_raw_ror = bit64.raw_ror local bit64_new = bit64.new local string_byte = string.byte local string_char = string.char @@ -109,14 +109,14 @@ local blake2b_v = create_blake2b_vector() --- @param x integer Message word x --- @param y integer Message word y local function blake2s_g(v, a, b, 
c, d, x, y) - v[a] = bit32_add(bit32_add(v[a], v[b]), x) - v[d] = bit32_ror(bit32_bxor(v[d], v[a]), 16) - v[c] = bit32_add(v[c], v[d]) - v[b] = bit32_ror(bit32_bxor(v[b], v[c]), 12) - v[a] = bit32_add(bit32_add(v[a], v[b]), y) - v[d] = bit32_ror(bit32_bxor(v[d], v[a]), 8) - v[c] = bit32_add(v[c], v[d]) - v[b] = bit32_ror(bit32_bxor(v[b], v[c]), 7) + v[a] = bit32_raw_add(bit32_raw_add(v[a], v[b]), x) + v[d] = bit32_raw_ror(bit32_raw_bxor(v[d], v[a]), 16) + v[c] = bit32_raw_add(v[c], v[d]) + v[b] = bit32_raw_ror(bit32_raw_bxor(v[b], v[c]), 12) + v[a] = bit32_raw_add(bit32_raw_add(v[a], v[b]), y) + v[d] = bit32_raw_ror(bit32_raw_bxor(v[d], v[a]), 8) + v[c] = bit32_raw_add(v[c], v[d]) + v[b] = bit32_raw_ror(bit32_raw_bxor(v[b], v[c]), 7) end --- BLAKE2b G function @@ -128,14 +128,14 @@ end --- @param x table Message word x --- @param y table Message word y local function blake2b_g(v, a, b, c, d, x, y) - v[a] = bit64_add(bit64_add(v[a], v[b]), x) - v[d] = bit64_ror(bit64_xor(v[d], v[a]), 32) - v[c] = bit64_add(v[c], v[d]) - v[b] = bit64_ror(bit64_xor(v[b], v[c]), 24) - v[a] = bit64_add(bit64_add(v[a], v[b]), y) - v[d] = bit64_ror(bit64_xor(v[d], v[a]), 16) - v[c] = bit64_add(v[c], v[d]) - v[b] = bit64_ror(bit64_xor(v[b], v[c]), 63) + v[a] = bit64_raw_add(bit64_raw_add(v[a], v[b]), x) + v[d] = bit64_raw_ror(bit64_raw_bxor(v[d], v[a]), 32) + v[c] = bit64_raw_add(v[c], v[d]) + v[b] = bit64_raw_ror(bit64_raw_bxor(v[b], v[c]), 24) + v[a] = bit64_raw_add(bit64_raw_add(v[a], v[b]), y) + v[d] = bit64_raw_ror(bit64_raw_bxor(v[d], v[a]), 16) + v[c] = bit64_raw_add(v[c], v[d]) + v[b] = bit64_raw_ror(bit64_raw_bxor(v[b], v[c]), 63) end --- BLAKE2s compression function @@ -159,10 +159,10 @@ local function blake2s_compress(h, m, t, th, f) end -- Mix in counter and final flag - v[13] = bit32_bxor(v[13], t) -- Low 32 bits of counter - v[14] = bit32_bxor(v[14], th) -- High 32 bits of counter + v[13] = bit32_raw_bxor(v[13], t) -- Low 32 bits of counter + v[14] = bit32_raw_bxor(v[14], th) 
-- High 32 bits of counter if f then - v[15] = bit32_bxor(v[15], 0xFFFFFFFF) -- Invert all bits for final block + v[15] = bit32_raw_bxor(v[15], 0xFFFFFFFF) -- Invert all bits for final block end -- 10 rounds @@ -185,7 +185,7 @@ local function blake2s_compress(h, m, t, th, f) -- Finalize for i = 1, 8 do - h[i] = bit32_bxor(bit32_bxor(h[i], v[i]), v[i + 8]) + h[i] = bit32_raw_bxor(bit32_raw_bxor(h[i], v[i]), v[i + 8]) end end @@ -209,10 +209,10 @@ local function blake2b_compress(h, m, t, f) end -- Mix in counter and final flag - v[13] = bit64_xor(v[13], t) - v[14] = bit64_xor(v[14], bit64_new(0, 0)) -- High 64 bits of counter (always 0 for messages < 2^64 bytes) + v[13] = bit64_raw_bxor(v[13], t) + v[14] = bit64_raw_bxor(v[14], bit64_new(0, 0)) -- High 64 bits of counter (always 0 for messages < 2^64 bytes) if f then - v[15] = bit64_xor(v[15], bit64_new(0xffffffff, 0xffffffff)) + v[15] = bit64_raw_bxor(v[15], bit64_new(0xffffffff, 0xffffffff)) end -- 12 rounds @@ -235,7 +235,7 @@ local function blake2b_compress(h, m, t, f) -- Finalize for i = 1, 8 do - h[i] = bit64_xor(bit64_xor(h[i], v[i]), v[i + 8]) + h[i] = bit64_raw_bxor(bit64_raw_bxor(h[i], v[i]), v[i + 8]) end end @@ -260,7 +260,7 @@ function blake2.blake2s(data) -- Parameter block: digest length = 32, key length = 0, fanout = 1, depth = 1 -- All other parameters are 0 (no salt, no personalization, etc.) 
local param = 32 + (0 * 256) + (1 * 65536) + (1 * 16777216) -- 0x01010020 - h[1] = bit32_bxor(h[1], param) + h[1] = bit32_raw_bxor(h[1], param) local data_len = #data local offset = 1 @@ -347,7 +347,7 @@ function blake2.blake2b(data) -- In little-endian 64-bit: 0x0000000001010040 -- Split into two 32-bit words (little-endian): low=0x01010040, high=0x00000000 -- But our u64 format is {high, low}, so we need {0x00000000, 0x01010040} - h[1] = bit64_xor(h[1], bit64_new(0x00000000, 0x01010040)) + h[1] = bit64_raw_bxor(h[1], bit64_new(0x00000000, 0x01010040)) local data_len = #data local offset = 1 @@ -355,7 +355,7 @@ function blake2.blake2b(data) -- Process full 128-byte blocks while offset + 127 <= data_len do - counter = bit64_add(counter, bit64_new(0, 128)) + counter = bit64_raw_add(counter, bit64_new(0, 128)) -- Check if this is the last block local is_last_block = (offset + 128 > data_len) @@ -374,7 +374,7 @@ function blake2.blake2b(data) -- Process final block (if there's remaining data) local remaining = data_len - offset + 1 if remaining > 0 then - counter = bit64_add(counter, bit64_new(0, remaining)) + counter = bit64_raw_add(counter, bit64_new(0, remaining)) -- Pad final block with zeros local final_block = data:sub(offset) .. 
string_rep("\0", 128 - remaining) @@ -451,8 +451,8 @@ function blake2.hmac_blake2s(key, data) local opad_bytes = {} for i = 1, block_size do local byte = string_byte(key, i) - ipad_bytes[i] = string_char(bit32_bxor(byte, 0x36)) - opad_bytes[i] = string_char(bit32_bxor(byte, 0x5C)) + ipad_bytes[i] = string_char(bit32_raw_bxor(byte, 0x36)) + opad_bytes[i] = string_char(bit32_raw_bxor(byte, 0x5C)) end local ipad = table_concat(ipad_bytes) local opad = table_concat(opad_bytes) @@ -494,8 +494,8 @@ function blake2.hmac_blake2b(key, data) local opad_bytes = {} for i = 1, block_size do local byte = string_byte(key, i) - ipad_bytes[i] = string_char(bit32_bxor(byte, 0x36)) - opad_bytes[i] = string_char(bit32_bxor(byte, 0x5C)) + ipad_bytes[i] = string_char(bit32_raw_bxor(byte, 0x36)) + opad_bytes[i] = string_char(bit32_raw_bxor(byte, 0x5C)) end local ipad = table_concat(ipad_bytes) local opad = table_concat(opad_bytes) diff --git a/src/noiseprotocol/crypto/chacha20.lua b/src/noiseprotocol/crypto/chacha20.lua index d341787..176365e 100644 --- a/src/noiseprotocol/crypto/chacha20.lua +++ b/src/noiseprotocol/crypto/chacha20.lua @@ -10,10 +10,10 @@ local utils = require("noiseprotocol.utils") local bytes = utils.bytes local benchmark_op = utils.benchmark.benchmark_op --- Local references for performance (avoid module table lookups in hot loops) -local bit32_add = bit32.add -local bit32_bxor = bit32.bxor -local bit32_rol = bit32.rol +-- Local references for performance +local bit32_raw_add = bit32.raw_add +local bit32_raw_bxor = bit32.raw_bxor +local bit32_raw_rol = bit32.raw_rol local floor = math.floor local min = math.min local string_byte = string.byte @@ -89,17 +89,17 @@ end --- @param c integer Index of third word --- @param d integer Index of fourth word local function quarter_round(state, a, b, c, d) - state[a] = bit32_add(state[a], state[b]) - state[d] = bit32_rol(bit32_bxor(state[d], state[a]), 16) + state[a] = bit32_raw_add(state[a], state[b]) + state[d] = 
bit32_raw_rol(bit32_raw_bxor(state[d], state[a]), 16) - state[c] = bit32_add(state[c], state[d]) - state[b] = bit32_rol(bit32_bxor(state[b], state[c]), 12) + state[c] = bit32_raw_add(state[c], state[d]) + state[b] = bit32_raw_rol(bit32_raw_bxor(state[b], state[c]), 12) - state[a] = bit32_add(state[a], state[b]) - state[d] = bit32_rol(bit32_bxor(state[d], state[a]), 8) + state[a] = bit32_raw_add(state[a], state[b]) + state[d] = bit32_raw_rol(bit32_raw_bxor(state[d], state[a]), 8) - state[c] = bit32_add(state[c], state[d]) - state[b] = bit32_rol(bit32_bxor(state[b], state[c]), 7) + state[c] = bit32_raw_add(state[c], state[d]) + state[b] = bit32_raw_rol(bit32_raw_bxor(state[b], state[c]), 7) end --- Generate one 64-byte block of ChaCha20 keystream @@ -170,7 +170,7 @@ local function chacha20_block(key, nonce, counter) -- Add original state to working state for i = 1, 16 do - working_state[i] = bit32_add(working_state[i], state[i]) + working_state[i] = bit32_raw_add(working_state[i], state[i]) end -- Convert state to byte string (little-endian) - optimized with local references @@ -206,7 +206,7 @@ function chacha20.crypt(key, nonce, plaintext, counter) for i = 1, block_size do local plaintext_byte = string_byte(plaintext, offset + i - 1) local keystream_byte = string_byte(keystream, i) - result_bytes[result_idx] = string_char(bit32_bxor(plaintext_byte, keystream_byte)) + result_bytes[result_idx] = string_char(bit32_raw_bxor(plaintext_byte, keystream_byte)) result_idx = result_idx + 1 end diff --git a/src/noiseprotocol/crypto/chacha20_poly1305.lua b/src/noiseprotocol/crypto/chacha20_poly1305.lua index 20e92c5..d0f8c4b 100644 --- a/src/noiseprotocol/crypto/chacha20_poly1305.lua +++ b/src/noiseprotocol/crypto/chacha20_poly1305.lua @@ -10,7 +10,7 @@ local benchmark_op = utils.benchmark.benchmark_op local chacha20 = require("noiseprotocol.crypto.chacha20") local poly1305 = require("noiseprotocol.crypto.poly1305") --- Local references for performance (avoid module table 
lookups in hot loops) +-- Local references for performance local string_char = string.char local string_rep = string.rep local string_sub = string.sub diff --git a/src/noiseprotocol/crypto/poly1305.lua b/src/noiseprotocol/crypto/poly1305.lua index 71e68ba..148abc3 100644 --- a/src/noiseprotocol/crypto/poly1305.lua +++ b/src/noiseprotocol/crypto/poly1305.lua @@ -9,9 +9,9 @@ local utils = require("noiseprotocol.utils") local bytes = utils.bytes local benchmark_op = utils.benchmark.benchmark_op --- Local references for performance (avoid module table lookups in hot loops) -local bit32_band = bit32.band -local bit32_lshift = bit32.lshift +-- Local references for performance +local bit32_raw_band = bit32.raw_band +local bit32_raw_lshift = bit32.raw_lshift local floor = math.floor local string_byte = string.byte local string_char = string.char @@ -49,7 +49,7 @@ local function reduce_high_order_terms(prod, start_pos, end_pos) local bit_offset = excess_bits % 8 if bit_offset > 0 then - reduction_multiplier = bit32_lshift(reduction_multiplier, bit_offset) + reduction_multiplier = bit32_raw_lshift(reduction_multiplier, bit_offset) end -- Add reduced value to target position @@ -181,13 +181,13 @@ function poly1305.authenticate(key, msg) -- Apply RFC 7539 clamping to ensure r has specific bit patterns -- This prevents certain classes of attacks and ensures key validity - r[4] = bit32_band(r[4], 15) -- Clear top 4 bits of 4th byte - r[5] = bit32_band(r[5], 252) -- Clear bottom 2 bits of 5th byte - r[8] = bit32_band(r[8], 15) -- Clear top 4 bits of 8th byte - r[9] = bit32_band(r[9], 252) -- Clear bottom 2 bits of 9th byte - r[12] = bit32_band(r[12], 15) -- Clear top 4 bits of 12th byte - r[13] = bit32_band(r[13], 252) -- Clear bottom 2 bits of 13th byte - r[16] = bit32_band(r[16], 15) -- Clear top 4 bits of 16th byte + r[4] = bit32_raw_band(r[4], 15) -- Clear top 4 bits of 4th byte + r[5] = bit32_raw_band(r[5], 252) -- Clear bottom 2 bits of 5th byte + r[8] = 
bit32_raw_band(r[8], 15) -- Clear top 4 bits of 8th byte + r[9] = bit32_raw_band(r[9], 252) -- Clear bottom 2 bits of 9th byte + r[12] = bit32_raw_band(r[12], 15) -- Clear top 4 bits of 12th byte + r[13] = bit32_raw_band(r[13], 252) -- Clear bottom 2 bits of 13th byte + r[16] = bit32_raw_band(r[16], 15) -- Clear top 4 bits of 16th byte -- Extract s (second 16 bytes) - used for final addition local s = create_key_array(key_bytes, 17) diff --git a/src/noiseprotocol/crypto/sha256.lua b/src/noiseprotocol/crypto/sha256.lua index 938b52f..96a8625 100644 --- a/src/noiseprotocol/crypto/sha256.lua +++ b/src/noiseprotocol/crypto/sha256.lua @@ -10,13 +10,13 @@ local utils = require("noiseprotocol.utils") local bytes = utils.bytes local benchmark_op = utils.benchmark.benchmark_op --- Local references for performance (avoid global/module table lookups in hot loops) -local bit32_bxor = bit32.bxor -local bit32_band = bit32.band -local bit32_bnot = bit32.bnot -local bit32_ror = bit32.ror -local bit32_rshift = bit32.rshift -local bit32_add = bit32.add +-- Local references for performance +local bit32_raw_bxor = bit32.raw_bxor +local bit32_raw_band = bit32.raw_band +local bit32_raw_bnot = bit32.raw_bnot +local bit32_raw_ror = bit32.raw_ror +local bit32_raw_rshift = bit32.raw_rshift +local bit32_raw_add = bit32.raw_add local bytes_be_bytes_to_u32 = bytes.be_bytes_to_u32 local bytes_u32_to_be_bytes = bytes.u32_to_be_bytes local string_char = string.char @@ -137,9 +137,9 @@ local function sha256_chunk(chunk, H) for i = 17, 64 do local w15 = W[i - 15] local w2 = W[i - 2] - local s0 = bit32_bxor(bit32_ror(w15, 7), bit32_bxor(bit32_ror(w15, 18), bit32_rshift(w15, 3))) - local s1 = bit32_bxor(bit32_ror(w2, 17), bit32_bxor(bit32_ror(w2, 19), bit32_rshift(w2, 10))) - W[i] = bit32_add(bit32_add(bit32_add(W[i - 16], s0), W[i - 7]), s1) + local s0 = bit32_raw_bxor(bit32_raw_ror(w15, 7), bit32_raw_bxor(bit32_raw_ror(w15, 18), bit32_raw_rshift(w15, 3))) + local s1 = 
bit32_raw_bxor(bit32_raw_ror(w2, 17), bit32_raw_bxor(bit32_raw_ror(w2, 19), bit32_raw_rshift(w2, 10))) + W[i] = bit32_raw_add(bit32_raw_add(bit32_raw_add(W[i - 16], s0), W[i - 7]), s1) end -- Initialize working variables @@ -148,32 +148,32 @@ local function sha256_chunk(chunk, H) -- Main loop (optimized with local references) for i = 1, 64 do local ki = K[i] - local S1 = bit32_bxor(bit32_ror(e, 6), bit32_bxor(bit32_ror(e, 11), bit32_ror(e, 25))) - local ch = bit32_bxor(bit32_band(e, f), bit32_band(bit32_bnot(e), g)) - local temp1 = bit32_add(bit32_add(bit32_add(bit32_add(h, S1), ch), ki), W[i]) - local S0 = bit32_bxor(bit32_ror(a, 2), bit32_bxor(bit32_ror(a, 13), bit32_ror(a, 22))) - local maj = bit32_bxor(bit32_band(a, b), bit32_bxor(bit32_band(a, c), bit32_band(b, c))) - local temp2 = bit32_add(S0, maj) + local S1 = bit32_raw_bxor(bit32_raw_ror(e, 6), bit32_raw_bxor(bit32_raw_ror(e, 11), bit32_raw_ror(e, 25))) + local ch = bit32_raw_bxor(bit32_raw_band(e, f), bit32_raw_band(bit32_raw_bnot(e), g)) + local temp1 = bit32_raw_add(bit32_raw_add(bit32_raw_add(bit32_raw_add(h, S1), ch), ki), W[i]) + local S0 = bit32_raw_bxor(bit32_raw_ror(a, 2), bit32_raw_bxor(bit32_raw_ror(a, 13), bit32_raw_ror(a, 22))) + local maj = bit32_raw_bxor(bit32_raw_band(a, b), bit32_raw_bxor(bit32_raw_band(a, c), bit32_raw_band(b, c))) + local temp2 = bit32_raw_add(S0, maj) h = g g = f f = e - e = bit32_add(d, temp1) + e = bit32_raw_add(d, temp1) d = c c = b b = a - a = bit32_add(temp1, temp2) + a = bit32_raw_add(temp1, temp2) end -- Add compressed chunk to current hash value - H[1] = bit32_add(H[1], a) - H[2] = bit32_add(H[2], b) - H[3] = bit32_add(H[3], c) - H[4] = bit32_add(H[4], d) - H[5] = bit32_add(H[5], e) - H[6] = bit32_add(H[6], f) - H[7] = bit32_add(H[7], g) - H[8] = bit32_add(H[8], h) + H[1] = bit32_raw_add(H[1], a) + H[2] = bit32_raw_add(H[2], b) + H[3] = bit32_raw_add(H[3], c) + H[4] = bit32_raw_add(H[4], d) + H[5] = bit32_raw_add(H[5], e) + H[6] = bit32_raw_add(H[6], f) + H[7] = 
bit32_raw_add(H[7], g) + H[8] = bit32_raw_add(H[8], h) end -- ============================================================================ @@ -267,8 +267,8 @@ function sha256.hmac_sha256(key, data) local opad_bytes = {} for i = 1, block_size do local byte = string_byte(key, i) - ipad_bytes[i] = string_char(bit32_bxor(byte, 0x36)) - opad_bytes[i] = string_char(bit32_bxor(byte, 0x5C)) + ipad_bytes[i] = string_char(bit32_raw_bxor(byte, 0x36)) + opad_bytes[i] = string_char(bit32_raw_bxor(byte, 0x5C)) end local ipad = table_concat(ipad_bytes) local opad = table_concat(opad_bytes) diff --git a/src/noiseprotocol/crypto/sha512.lua b/src/noiseprotocol/crypto/sha512.lua index 0cad989..faede7e 100644 --- a/src/noiseprotocol/crypto/sha512.lua +++ b/src/noiseprotocol/crypto/sha512.lua @@ -12,15 +12,15 @@ local utils = require("noiseprotocol.utils") local bytes = utils.bytes local benchmark_op = utils.benchmark.benchmark_op --- Local references for performance (avoid module table lookups in hot loops) -local bit64_add = bit64.add -local bit64_xor = bit64.xor -local bit64_band = bit64.band -local bit64_bnot = bit64.bnot -local bit64_ror = bit64.ror -local bit64_shr = bit64.shr +-- Local references for performance +local bit64_raw_add = bit64.raw_add +local bit64_raw_bxor = bit64.raw_bxor +local bit64_raw_band = bit64.raw_band +local bit64_raw_bnot = bit64.raw_bnot +local bit64_raw_ror = bit64.raw_ror +local bit64_raw_rshift = bit64.raw_rshift local bit64_new = bit64.new -local bit32_bxor = bit32.bxor +local bit32_raw_bxor = bit32.raw_bxor local string_char = string.char local string_rep = string.rep local string_byte = string.byte @@ -147,28 +147,28 @@ local chunk_W = create_message_schedule_64() --- @param x Int64HighLow {high, low} input --- @return Int64HighLow {high, low} result local function Sigma0(x) - return bit64_xor(bit64_xor(bit64_ror(x, 28), bit64_ror(x, 34)), bit64_ror(x, 39)) + return bit64_raw_bxor(bit64_raw_bxor(bit64_raw_ror(x, 28), bit64_raw_ror(x, 34)), 
bit64_raw_ror(x, 39)) end --- SHA-512 Sigma1 function --- @param x Int64HighLow {high, low} input --- @return Int64HighLow {high, low} result local function Sigma1(x) - return bit64_xor(bit64_xor(bit64_ror(x, 14), bit64_ror(x, 18)), bit64_ror(x, 41)) + return bit64_raw_bxor(bit64_raw_bxor(bit64_raw_ror(x, 14), bit64_raw_ror(x, 18)), bit64_raw_ror(x, 41)) end --- SHA-512 sigma0 function --- @param x Int64HighLow {high, low} input --- @return Int64HighLow {high, low} result local function sigma0(x) - return bit64_xor(bit64_xor(bit64_ror(x, 1), bit64_ror(x, 8)), bit64_shr(x, 7)) + return bit64_raw_bxor(bit64_raw_bxor(bit64_raw_ror(x, 1), bit64_raw_ror(x, 8)), bit64_raw_rshift(x, 7)) end --- SHA-512 sigma1 function --- @param x Int64HighLow {high, low} input --- @return Int64HighLow {high, low} result local function sigma1(x) - return bit64_xor(bit64_xor(bit64_ror(x, 19), bit64_ror(x, 61)), bit64_shr(x, 6)) + return bit64_raw_bxor(bit64_raw_bxor(bit64_raw_ror(x, 19), bit64_raw_ror(x, 61)), bit64_raw_rshift(x, 6)) end --- SHA-512 Ch function @@ -177,7 +177,7 @@ end --- @param z Int64HighLow {high, low} input --- @return Int64HighLow {high, low} result local function Ch(x, y, z) - return bit64_xor(bit64_band(x, y), bit64_band(bit64_bnot(x), z)) + return bit64_raw_bxor(bit64_raw_band(x, y), bit64_raw_band(bit64_raw_bnot(x), z)) end --- SHA-512 Maj function @@ -186,7 +186,7 @@ end --- @param z Int64HighLow {high, low} input --- @return Int64HighLow {high, low} result local function Maj(x, y, z) - return bit64_xor(bit64_xor(bit64_band(x, y), bit64_band(x, z)), bit64_band(y, z)) + return bit64_raw_bxor(bit64_raw_bxor(bit64_raw_band(x, y), bit64_raw_band(x, z)), bit64_raw_band(y, z)) end --- SHA-512 core compression function @@ -208,7 +208,7 @@ local function sha512_chunk(chunk, H) local w2 = W[i - 2] local s0 = sigma0(w15) local s1 = sigma1(w2) - local result = bit64_add(bit64_add(bit64_add(W[i - 16], s0), W[i - 7]), s1) + local result = 
bit64_raw_add(bit64_raw_add(bit64_raw_add(W[i - 16], s0), W[i - 7]), s1) W[i][1], W[i][2] = result[1], result[2] end @@ -220,30 +220,30 @@ local function sha512_chunk(chunk, H) local prime = K[i] local S1 = Sigma1(e) local ch = Ch(e, f, g) - local temp1 = bit64_add(bit64_add(bit64_add(bit64_add(h, S1), ch), prime), W[i]) + local temp1 = bit64_raw_add(bit64_raw_add(bit64_raw_add(bit64_raw_add(h, S1), ch), prime), W[i]) local S0 = Sigma0(a) local maj = Maj(a, b, c) - local temp2 = bit64_add(S0, maj) + local temp2 = bit64_raw_add(S0, maj) h = g g = f f = e - e = bit64_add(d, temp1) + e = bit64_raw_add(d, temp1) d = c c = b b = a - a = bit64_add(temp1, temp2) + a = bit64_raw_add(temp1, temp2) end -- Add compressed chunk to current hash value - H[1] = bit64_add(H[1], a) - H[2] = bit64_add(H[2], b) - H[3] = bit64_add(H[3], c) - H[4] = bit64_add(H[4], d) - H[5] = bit64_add(H[5], e) - H[6] = bit64_add(H[6], f) - H[7] = bit64_add(H[7], g) - H[8] = bit64_add(H[8], h) + H[1] = bit64_raw_add(H[1], a) + H[2] = bit64_raw_add(H[2], b) + H[3] = bit64_raw_add(H[3], c) + H[4] = bit64_raw_add(H[4], d) + H[5] = bit64_raw_add(H[5], e) + H[6] = bit64_raw_add(H[6], f) + H[7] = bit64_raw_add(H[7], g) + H[8] = bit64_raw_add(H[8], h) end --- Compute SHA-512 hash of input data @@ -338,8 +338,8 @@ function sha512.hmac_sha512(key, data) local opad_bytes = {} for i = 1, block_size do local byte = string_byte(key, i) - ipad_bytes[i] = string_char(bit32_bxor(byte, 0x36)) - opad_bytes[i] = string_char(bit32_bxor(byte, 0x5C)) + ipad_bytes[i] = string_char(bit32_raw_bxor(byte, 0x36)) + opad_bytes[i] = string_char(bit32_raw_bxor(byte, 0x5C)) end local ipad = table_concat(ipad_bytes) local opad = table_concat(opad_bytes) diff --git a/src/noiseprotocol/crypto/x25519.lua b/src/noiseprotocol/crypto/x25519.lua index be1073b..b6fe347 100644 --- a/src/noiseprotocol/crypto/x25519.lua +++ b/src/noiseprotocol/crypto/x25519.lua @@ -9,17 +9,19 @@ local utils = require("noiseprotocol.utils") local bytes = 
utils.bytes local benchmark_op = utils.benchmark.benchmark_op --- ============================================================================ --- CURVE25519 FIELD ARITHMETIC --- ============================================================================ - --- Local references for performance (avoid global table lookups in hot loops) +-- Local references for performance +local bit32_raw_band = bit32.raw_band +local bit32_raw_rshift = bit32.raw_rshift local floor = math.floor local string_byte = string.byte local string_char = string.char local string_rep = string.rep local table_concat = table.concat +-- ============================================================================ +-- CURVE25519 FIELD ARITHMETIC +-- ============================================================================ + --- @alias FieldElement integer[] 16-element array (indices 1-16) representing a field element --- @alias ProductArray integer[] 31-element array (indices 1-31) for multiplication products --- @alias ScalarArray integer[] 32-element array (indices 1-32) for scalar bytes @@ -252,7 +254,7 @@ local function scalarmult(out, scalar, point) -- Optimized bit extraction local byte_idx = floor(i * 0.125) + 1 -- i / 8 + 1 local bit_idx = i % 8 - local bit = bit32.band(bit32.rshift(clam[byte_idx], bit_idx), 1) + local bit = bit32_raw_band(bit32_raw_rshift(clam[byte_idx], bit_idx), 1) swap(a, b, bit) swap(c, d, bit) add(e, a, c) diff --git a/src/noiseprotocol/crypto/x448.lua b/src/noiseprotocol/crypto/x448.lua index 6f343f9..d0844d2 100644 --- a/src/noiseprotocol/crypto/x448.lua +++ b/src/noiseprotocol/crypto/x448.lua @@ -13,17 +13,19 @@ local x448 = {} local bitn = require("bitn") -local band = bitn.bit32.band -local bor = bitn.bit32.bor -local bxor = bitn.bit32.bxor -local rshift = bitn.bit32.rshift - local utils = require("noiseprotocol.utils") + local bytes = utils.bytes local benchmark_op = utils.benchmark.benchmark_op -local floor = math.floor -local char = string.char + +-- Local 
references for performance +local bit32_raw_band = bitn.bit32.raw_band +local bit32_raw_bor = bitn.bit32.raw_bor +local bit32_raw_bxor = bitn.bit32.raw_bxor +local bit32_raw_rshift = bitn.bit32.raw_rshift local byte = string.byte +local char = string.char +local floor = math.floor local string_rep = string.rep local table_concat = table.concat @@ -70,7 +72,7 @@ local function fe_reduce(a) local carry = 0 for i = 1, NUM_LIMBS do carry = carry + (a[i] or 0) - a[i] = band(carry, LIMB_MASK) + a[i] = bit32_raw_band(carry, LIMB_MASK) carry = floor(carry / 256) end @@ -83,7 +85,7 @@ local function fe_reduce(a) local new_carry = 0 for i = 1, NUM_LIMBS do new_carry = new_carry + a[i] - a[i] = band(new_carry, LIMB_MASK) + a[i] = bit32_raw_band(new_carry, LIMB_MASK) new_carry = floor(new_carry / 256) end carry = new_carry @@ -128,17 +130,17 @@ local function fe_sub(a, b) local carry = 0 for i = 1, 28 do local sum = r[i] + 0xFF + carry - r[i] = band(sum, LIMB_MASK) + r[i] = bit32_raw_band(sum, LIMB_MASK) carry = floor(sum / 256) end local sum = r[29] + 0xFE + carry - r[29] = band(sum, LIMB_MASK) + r[29] = bit32_raw_band(sum, LIMB_MASK) carry = floor(sum / 256) for i = 30, NUM_LIMBS do sum = r[i] + 0xFF + carry - r[i] = band(sum, LIMB_MASK) + r[i] = bit32_raw_band(sum, LIMB_MASK) carry = floor(sum / 256) end end @@ -177,7 +179,7 @@ local function fe_mul(a, b) local carry = 0 for i = 1, 2 * NUM_LIMBS do local sum = r[i] + carry - r[i] = band(sum, LIMB_MASK) + r[i] = bit32_raw_band(sum, LIMB_MASK) carry = floor(sum / 256) end @@ -213,7 +215,7 @@ local function fe_mul(a, b) carry = 0 for i = 1, NUM_LIMBS do local sum = r[i] + carry - r[i] = band(sum, LIMB_MASK) + r[i] = bit32_raw_band(sum, LIMB_MASK) carry = floor(sum / 256) end @@ -225,7 +227,7 @@ local function fe_mul(a, b) carry = 0 for i = 1, NUM_LIMBS do local sum = r[i] + carry - r[i] = band(sum, LIMB_MASK) + r[i] = bit32_raw_band(sum, LIMB_MASK) carry = floor(sum / 256) end end @@ -356,7 +358,7 @@ local function 
fe_tobytes(a) -- Convert to bytes - with 8-bit limbs it's direct local b = {} for i = 1, NUM_LIMBS do - b[i] = char(band(t[i] or 0, 0xFF)) + b[i] = char(bit32_raw_band(t[i] or 0, 0xFF)) end return table_concat(b) @@ -375,8 +377,8 @@ local function x448_scalarmult(scalar, base) for i = 1, 56 do k[i] = byte(scalar, i) or 0 end - k[1] = band(k[1], 252) -- Clear low 2 bits - k[56] = bor(k[56], 128) -- Set high bit + k[1] = bit32_raw_band(k[1], 252) -- Clear low 2 bits + k[56] = bit32_raw_bor(k[56], 128) -- Set high bit -- Initialize Montgomery ladder local x_1 = fe_copy(u) @@ -388,12 +390,12 @@ local function x448_scalarmult(scalar, base) -- Montgomery ladder for t = 447, 0, -1 do - local byte_idx = rshift(t, 3) + 1 -- t // 8 + 1 - local bit_idx = band(t, 7) -- t % 8 - local kt = band(rshift(k[byte_idx], bit_idx), 1) + local byte_idx = bit32_raw_rshift(t, 3) + 1 -- t // 8 + 1 + local bit_idx = bit32_raw_band(t, 7) -- t % 8 + local kt = bit32_raw_band(bit32_raw_rshift(k[byte_idx], bit_idx), 1) -- Conditional swap - swap = bxor(swap, kt) + swap = bit32_raw_bxor(swap, kt) x_2, x_3 = cswap(swap, x_2, x_3) z_2, z_3 = cswap(swap, z_2, z_3) swap = kt @@ -415,8 +417,8 @@ local function x448_scalarmult(scalar, base) -- z_2 = e * (aa + a24 * e) local a24_limbs = fe_zero() - a24_limbs[1] = band(A24, 0xFF) - a24_limbs[2] = band(rshift(A24, 8), 0xFF) + a24_limbs[1] = bit32_raw_band(A24, 0xFF) + a24_limbs[2] = bit32_raw_band(bit32_raw_rshift(A24, 8), 0xFF) local a24_e = fe_mul(a24_limbs, e) z_2 = fe_mul(e, fe_add(aa, a24_e)) diff --git a/src/noiseprotocol/utils/bytes.lua b/src/noiseprotocol/utils/bytes.lua index abfe2aa..c392c6a 100644 --- a/src/noiseprotocol/utils/bytes.lua +++ b/src/noiseprotocol/utils/bytes.lua @@ -7,12 +7,24 @@ local bitn = require("bitn") local bit32 = bitn.bit32 local bit64 = bitn.bit64 +-- Local references for performance +local bit32_mask = bit32.mask +local bit32_raw_bor = bit32.raw_bor +local bit32_raw_bxor = bit32.raw_bxor +local bit64_new = bit64.new 
+local floor = math.floor +local string_byte = string.byte +local string_char = string.char +local string_format = string.format +local string_rep = string.rep +local table_concat = table.concat + --- Convert binary string to hexadecimal string --- @param str string Binary string --- @return string hex Hexadecimal representation function bytes.to_hex(str) return (str:gsub(".", function(c) - return string.format("%02x", string.byte(c)) + return string_format("%02x", string_byte(c)) end)) end @@ -21,7 +33,7 @@ end --- @return string str Binary string function bytes.from_hex(hex) return (hex:gsub("..", function(cc) - return string.char(tonumber(cc, 16)) + return string_char(tonumber(cc, 16)) end)) end @@ -29,16 +41,16 @@ end --- @param n integer 32-bit unsigned integer --- @return string bytes 4-byte string in little-endian order function bytes.u32_to_le_bytes(n) - n = bit32.mask(n) - return string.char(n % 256, math.floor(n / 256) % 256, math.floor(n / 65536) % 256, math.floor(n / 16777216) % 256) + n = bit32_mask(n) + return string_char(n % 256, floor(n / 256) % 256, floor(n / 65536) % 256, floor(n / 16777216) % 256) end --- Convert 32-bit unsigned integer to 4 bytes (big-endian) --- @param n integer 32-bit unsigned integer --- @return string bytes 4-byte string in big-endian order function bytes.u32_to_be_bytes(n) - n = bit32.mask(n) - return string.char(math.floor(n / 16777216) % 256, math.floor(n / 65536) % 256, math.floor(n / 256) % 256, n % 256) + n = bit32_mask(n) + return string_char(floor(n / 16777216) % 256, floor(n / 65536) % 256, floor(n / 256) % 256, n % 256) end --- Convert 64-bit value to 8 bytes (big-endian) @@ -56,7 +68,7 @@ function bytes.u64_to_le_bytes(x) -- Handle simple integer case (< 2^53) if type(x) == "number" then local low = x % 0x100000000 - local high = math.floor(x / 0x100000000) + local high = floor(x / 0x100000000) return bytes.u32_to_le_bytes(low) .. 
bytes.u32_to_le_bytes(high) else -- Handle {high, low} pair @@ -72,7 +84,7 @@ end function bytes.le_bytes_to_u32(str, offset) offset = offset or 1 assert(#str >= offset + 3, "Insufficient bytes for u32") - local b1, b2, b3, b4 = string.byte(str, offset, offset + 3) + local b1, b2, b3, b4 = string_byte(str, offset, offset + 3) return b1 + b2 * 256 + b3 * 65536 + b4 * 16777216 end @@ -83,7 +95,7 @@ end function bytes.be_bytes_to_u32(str, offset) offset = offset or 1 assert(#str >= offset + 3, "Insufficient bytes for u32") - local b1, b2, b3, b4 = string.byte(str, offset, offset + 3) + local b1, b2, b3, b4 = string_byte(str, offset, offset + 3) return b1 * 16777216 + b2 * 65536 + b3 * 256 + b4 end @@ -118,10 +130,11 @@ end function bytes.xor_bytes(a, b) assert(#a == #b, "Strings must be same length for XOR") local result = {} + -- Using raw_bxor for performance; XOR on bytes (0-255) is always safe for i = 1, #a do - result[i] = string.char(bit32.bxor(string.byte(a, i), string.byte(b, i))) + result[i] = string_char(bit32_raw_bxor(string_byte(a, i), string_byte(b, i))) end - return table.concat(result) + return table_concat(result) end --- Constant-time comparison of two strings @@ -134,7 +147,7 @@ function bytes.constant_time_compare(a, b) end local result = 0 for i = 1, #a do - result = bit32.bor(result, bit32.bxor(string.byte(a, i), string.byte(b, i))) + result = bit32_raw_bor(result, bit32_raw_bxor(string_byte(a, i), string_byte(b, i))) end return result == 0 end @@ -148,7 +161,7 @@ function bytes.pad_to_16(data) if padding_len == 0 then return data end - return data .. string.rep("\0", padding_len) + return data .. 
string_rep("\0", padding_len) end --- Run comprehensive self-test with test vectors @@ -181,7 +194,7 @@ function bytes.selftest() { name = "hex - single byte min", test = function() - local data = string.char(0x00) + local data = string_char(0x00) local hex = bytes.to_hex(data) return hex == "00" end, @@ -189,7 +202,7 @@ function bytes.selftest() { name = "hex - single byte max", test = function() - local data = string.char(0xFF) + local data = string_char(0xFF) local hex = bytes.to_hex(data) return hex == "ff" end, @@ -198,7 +211,7 @@ function bytes.selftest() name = "hex - all byte values", test = function() -- Test a few representative byte values - local data = string.char(0x00, 0x01, 0x7F, 0x80, 0xFE, 0xFF) + local data = string_char(0x00, 0x01, 0x7F, 0x80, 0xFE, 0xFF) local hex = bytes.to_hex(data) return hex == "00017f80feff" end, @@ -214,7 +227,7 @@ function bytes.selftest() { name = "hex - binary data", test = function() - local data = string.char(0x80, 0x90, 0xA0, 0xB0, 0xC0, 0xD0, 0xE0, 0xF0) + local data = string_char(0x80, 0x90, 0xA0, 0xB0, 0xC0, 0xD0, 0xE0, 0xF0) local hex = bytes.to_hex(data) local back = bytes.from_hex(hex) return hex == "8090a0b0c0d0e0f0" and back == data @@ -228,7 +241,7 @@ function bytes.selftest() local n = 0x12345678 local bytes_str = bytes.u32_to_le_bytes(n) local back = bytes.le_bytes_to_u32(bytes_str) - local b1, b2, b3, b4 = string.byte(bytes_str, 1, 4) + local b1, b2, b3, b4 = string_byte(bytes_str, 1, 4) return back == n and b1 == 0x78 and b2 == 0x56 and b3 == 0x34 and b4 == 0x12 end, }, @@ -238,7 +251,7 @@ function bytes.selftest() local n = 0 local bytes_str = bytes.u32_to_le_bytes(n) local back = bytes.le_bytes_to_u32(bytes_str) - return back == 0 and bytes_str == string.char(0, 0, 0, 0) + return back == 0 and bytes_str == string_char(0, 0, 0, 0) end, }, { @@ -247,7 +260,7 @@ function bytes.selftest() local n = 0xFFFFFFFF local bytes_str = bytes.u32_to_le_bytes(n) local back = bytes.le_bytes_to_u32(bytes_str) - return 
back == 0xFFFFFFFF and bytes_str == string.char(0xFF, 0xFF, 0xFF, 0xFF) + return back == 0xFFFFFFFF and bytes_str == string_char(0xFF, 0xFF, 0xFF, 0xFF) end, }, { @@ -255,7 +268,7 @@ function bytes.selftest() test = function() local n = 0x100000000 -- Should be masked to 0 local bytes_str = bytes.u32_to_le_bytes(n) - return bytes_str == string.char(0, 0, 0, 0) + return bytes_str == string_char(0, 0, 0, 0) end, }, { @@ -264,13 +277,13 @@ function bytes.selftest() local n = 0x80000000 local bytes_str = bytes.u32_to_le_bytes(n) local back = bytes.le_bytes_to_u32(bytes_str) - return back == 0x80000000 and bytes_str == string.char(0, 0, 0, 0x80) + return back == 0x80000000 and bytes_str == string_char(0, 0, 0, 0x80) end, }, { name = "u32 LE - with offset", test = function() - local data = "XXX" .. string.char(0x78, 0x56, 0x34, 0x12) .. "YYY" + local data = "XXX" .. string_char(0x78, 0x56, 0x34, 0x12) .. "YYY" local n = bytes.le_bytes_to_u32(data, 4) return n == 0x12345678 end, @@ -281,7 +294,7 @@ function bytes.selftest() local n = 0x12345678 local bytes_str = bytes.u32_to_be_bytes(n) local back = bytes.be_bytes_to_u32(bytes_str) - local b1, b2, b3, b4 = string.byte(bytes_str, 1, 4) + local b1, b2, b3, b4 = string_byte(bytes_str, 1, 4) return back == n and b1 == 0x12 and b2 == 0x34 and b3 == 0x56 and b4 == 0x78 end, }, @@ -291,7 +304,7 @@ function bytes.selftest() local n = 0 local bytes_str = bytes.u32_to_be_bytes(n) local back = bytes.be_bytes_to_u32(bytes_str) - return back == 0 and bytes_str == string.char(0, 0, 0, 0) + return back == 0 and bytes_str == string_char(0, 0, 0, 0) end, }, { @@ -300,13 +313,13 @@ function bytes.selftest() local n = 0xFFFFFFFF local bytes_str = bytes.u32_to_be_bytes(n) local back = bytes.be_bytes_to_u32(bytes_str) - return back == 0xFFFFFFFF and bytes_str == string.char(0xFF, 0xFF, 0xFF, 0xFF) + return back == 0xFFFFFFFF and bytes_str == string_char(0xFF, 0xFF, 0xFF, 0xFF) end, }, { name = "u32 BE - with offset", test = function() - local 
data = "XXX" .. string.char(0x12, 0x34, 0x56, 0x78) .. "YYY" + local data = "XXX" .. string_char(0x12, 0x34, 0x56, 0x78) .. "YYY" local n = bytes.be_bytes_to_u32(data, 4) return n == 0x12345678 end, @@ -316,10 +329,10 @@ function bytes.selftest() { name = "u64 LE - basic table", test = function() - local n = bit64.new(0x12345678, 0x9ABCDEF0) + local n = bit64_new(0x12345678, 0x9ABCDEF0) local bytes_str = bytes.u64_to_le_bytes(n) local back = bytes.le_bytes_to_u64(bytes_str) - local b1, b2, b3, b4, b5, b6, b7, b8 = string.byte(bytes_str, 1, 8) + local b1, b2, b3, b4, b5, b6, b7, b8 = string_byte(bytes_str, 1, 8) return back[1] == n[1] and back[2] == n[2] and b1 == 0xF0 @@ -340,32 +353,32 @@ function bytes.selftest() local back = bytes.le_bytes_to_u64(bytes_str) -- Check the conversion worked correctly local expected_low = n % 0x100000000 - local expected_high = math.floor(n / 0x100000000) + local expected_high = floor(n / 0x100000000) return back[1] == expected_high and back[2] == expected_low end, }, { name = "u64 LE - zero", test = function() - local n = bit64.new(0, 0) + local n = bit64_new(0, 0) local bytes_str = bytes.u64_to_le_bytes(n) local back = bytes.le_bytes_to_u64(bytes_str) - return back[1] == 0 and back[2] == 0 and bytes_str == string.rep(string.char(0), 8) + return back[1] == 0 and back[2] == 0 and bytes_str == string_rep(string_char(0), 8) end, }, { name = "u64 LE - max value", test = function() - local n = bit64.new(0xFFFFFFFF, 0xFFFFFFFF) + local n = bit64_new(0xFFFFFFFF, 0xFFFFFFFF) local bytes_str = bytes.u64_to_le_bytes(n) local back = bytes.le_bytes_to_u64(bytes_str) - return back[1] == 0xFFFFFFFF and back[2] == 0xFFFFFFFF and bytes_str == string.rep(string.char(0xFF), 8) + return back[1] == 0xFFFFFFFF and back[2] == 0xFFFFFFFF and bytes_str == string_rep(string_char(0xFF), 8) end, }, { name = "u64 LE - high word only", test = function() - local n = bit64.new(0x12345678, 0) + local n = bit64_new(0x12345678, 0) local bytes_str = 
bytes.u64_to_le_bytes(n) local back = bytes.le_bytes_to_u64(bytes_str) return back[1] == 0x12345678 and back[2] == 0 @@ -374,7 +387,7 @@ function bytes.selftest() { name = "u64 LE - low word only", test = function() - local n = bit64.new(0, 0x12345678) + local n = bit64_new(0, 0x12345678) local bytes_str = bytes.u64_to_le_bytes(n) local back = bytes.le_bytes_to_u64(bytes_str) return back[1] == 0 and back[2] == 0x12345678 @@ -383,7 +396,7 @@ function bytes.selftest() { name = "u64 LE - with offset", test = function() - local data = "XXX" .. bytes.u64_to_le_bytes(bit64.new(0x12345678, 0x9ABCDEF0)) .. "YYY" + local data = "XXX" .. bytes.u64_to_le_bytes(bit64_new(0x12345678, 0x9ABCDEF0)) .. "YYY" local n = bytes.le_bytes_to_u64(data, 4) return n[1] == 0x12345678 and n[2] == 0x9ABCDEF0 end, @@ -391,10 +404,10 @@ function bytes.selftest() { name = "u64 BE - basic", test = function() - local n = bit64.new(0x12345678, 0x9ABCDEF0) + local n = bit64_new(0x12345678, 0x9ABCDEF0) local bytes_str = bytes.u64_to_be_bytes(n) local back = bytes.be_bytes_to_u64(bytes_str) - local b1, b2, b3, b4, b5, b6, b7, b8 = string.byte(bytes_str, 1, 8) + local b1, b2, b3, b4, b5, b6, b7, b8 = string_byte(bytes_str, 1, 8) return back[1] == n[1] and back[2] == n[2] and b1 == 0x12 @@ -410,16 +423,16 @@ function bytes.selftest() { name = "u64 BE - zero", test = function() - local n = bit64.new(0, 0) + local n = bit64_new(0, 0) local bytes_str = bytes.u64_to_be_bytes(n) local back = bytes.be_bytes_to_u64(bytes_str) - return back[1] == 0 and back[2] == 0 and bytes_str == string.rep(string.char(0), 8) + return back[1] == 0 and back[2] == 0 and bytes_str == string_rep(string_char(0), 8) end, }, { name = "u64 BE - with offset", test = function() - local data = "XXX" .. bytes.u64_to_be_bytes(bit64.new(0x12345678, 0x9ABCDEF0)) .. "YYY" + local data = "XXX" .. bytes.u64_to_be_bytes(bit64_new(0x12345678, 0x9ABCDEF0)) .. 
"YYY" local n = bytes.be_bytes_to_u64(data, 4) return n[1] == 0x12345678 and n[2] == 0x9ABCDEF0 end, @@ -429,10 +442,10 @@ function bytes.selftest() { name = "xor - basic", test = function() - local a = string.char(0x01, 0x02, 0x03, 0x04) - local b = string.char(0xFF, 0xFE, 0xFD, 0xFC) + local a = string_char(0x01, 0x02, 0x03, 0x04) + local b = string_char(0xFF, 0xFE, 0xFD, 0xFC) local result = bytes.xor_bytes(a, b) - local r1, r2, r3, r4 = string.byte(result, 1, 4) + local r1, r2, r3, r4 = string_byte(result, 1, 4) return r1 == 0xFE and r2 == 0xFC and r3 == 0xFE and r4 == 0xF8 end, }, @@ -448,10 +461,10 @@ function bytes.selftest() { name = "xor - single byte", test = function() - local a = string.char(0x00) - local b = string.char(0xFF) + local a = string_char(0x00) + local b = string_char(0xFF) local result = bytes.xor_bytes(a, b) - return result == string.char(0xFF) + return result == string_char(0xFF) end, }, { @@ -459,23 +472,23 @@ function bytes.selftest() test = function() local a = "test" local result = bytes.xor_bytes(a, a) - return result == string.char(0, 0, 0, 0) + return result == string_char(0, 0, 0, 0) end, }, { name = "xor - all zeros pattern", test = function() - local a = string.char(0xAA, 0xBB, 0xCC, 0xDD) - local b = string.char(0xAA, 0xBB, 0xCC, 0xDD) + local a = string_char(0xAA, 0xBB, 0xCC, 0xDD) + local b = string_char(0xAA, 0xBB, 0xCC, 0xDD) local result = bytes.xor_bytes(a, b) - return result == string.char(0, 0, 0, 0) + return result == string_char(0, 0, 0, 0) end, }, { name = "xor - identity with zeros", test = function() - local a = string.char(0x12, 0x34, 0x56, 0x78) - local b = string.char(0, 0, 0, 0) + local a = string_char(0x12, 0x34, 0x56, 0x78) + local b = string_char(0, 0, 0, 0) local result = bytes.xor_bytes(a, b) return result == a end, @@ -533,8 +546,8 @@ function bytes.selftest() { name = "constant_time_compare - binary with nulls", test = function() - local a = string.char(0x00, 0x01, 0xFF) - local b = string.char(0x00, 
0x01, 0xFF) + local a = string_char(0x00, 0x01, 0xFF) + local b = string_char(0x00, 0x01, 0xFF) return bytes.constant_time_compare(a, b) == true end, }, @@ -543,7 +556,7 @@ function bytes.selftest() { name = "pad_to_16 - no padding needed", test = function() - local data = string.rep("a", 16) + local data = string_rep("a", 16) local padded = bytes.pad_to_16(data) return padded == data and #padded == 16 end, @@ -553,7 +566,7 @@ function bytes.selftest() test = function() local data = "Hello" local padded = bytes.pad_to_16(data) - return #padded == 16 and padded:sub(1, 5) == "Hello" and padded:sub(6) == string.rep("\0", 11) + return #padded == 16 and padded:sub(1, 5) == "Hello" and padded:sub(6) == string_rep("\0", 11) end, }, { @@ -567,7 +580,7 @@ function bytes.selftest() { name = "pad_to_16 - exactly 32 bytes", test = function() - local data = string.rep("a", 32) + local data = string_rep("a", 32) local padded = bytes.pad_to_16(data) return padded == data and #padded == 32 end, @@ -575,7 +588,7 @@ function bytes.selftest() { name = "pad_to_16 - one byte short", test = function() - local data = string.rep("a", 15) + local data = string_rep("a", 15) local padded = bytes.pad_to_16(data) return #padded == 16 and padded:sub(1, 15) == data and padded:sub(16) == "\0" end, @@ -583,15 +596,15 @@ function bytes.selftest() { name = "pad_to_16 - one byte over", test = function() - local data = string.rep("a", 17) + local data = string_rep("a", 17) local padded = bytes.pad_to_16(data) - return #padded == 32 and padded:sub(1, 17) == data and padded:sub(18) == string.rep("\0", 15) + return #padded == 32 and padded:sub(1, 17) == data and padded:sub(18) == string_rep("\0", 15) end, }, { name = "pad_to_16 - large data", test = function() - local data = string.rep("a", 1000) + local data = string_rep("a", 1000) local padded = bytes.pad_to_16(data) local expected_len = math.ceil(1000 / 16) * 16 return #padded == expected_len and padded:sub(1, 1000) == data @@ -660,7 +673,7 @@ 
function bytes.selftest() end end - print(string.format("\nByte operations result: %d/%d tests passed\n", passed, total)) + print(string_format("\nByte operations result: %d/%d tests passed\n", passed, total)) return passed == total end diff --git a/vendor/bitn.lua b/vendor/bitn.lua index f83484d..92c9cbe 100644 --- a/vendor/bitn.lua +++ b/vendor/bitn.lua @@ -100,6 +100,33 @@ if ok and result then return native_band(r, MASK32) end + -- Raw operations provide direct access to native bit functions without the + -- to_unsigned() wrapper. On Lua 5.3+, these are identical to wrapped versions + -- since native operators already return unsigned values. + -- Shifts must mask to 32 bits since native operators work on 64-bit values. + _compat.raw_band = native_band + _compat.raw_bor = native_bor + _compat.raw_bxor = native_bxor + _compat.raw_bnot = function(a) + return native_band(native_bnot(a), MASK32) + end + _compat.raw_lshift = function(a, n) + if n >= 32 then + return 0 + end + return native_band(native_lshift(a, n), MASK32) + end + _compat.raw_rshift = function(a, n) + if n >= 32 then + return 0 + end + return native_rshift(native_band(a, MASK32), n) + end + _compat.raw_arshift = _compat.arshift + -- No native rol/ror on Lua 5.3+ + _compat.raw_rol = nil + _compat.raw_ror = nil + return _compat end end @@ -221,6 +248,25 @@ if bit_lib then end end + -- Raw operations provide direct access to native bit functions without the + -- to_unsigned() wrapper. On LuaJIT, these return signed 32-bit integers. + -- On Lua 5.2 (bit32 library), these are identical to wrapped versions. 
+ _compat.raw_band = bit_band + _compat.raw_bor = bit_bor + _compat.raw_bxor = bit_bxor + _compat.raw_bnot = bit_bnot + _compat.raw_lshift = bit_lshift + _compat.raw_rshift = bit_rshift + _compat.raw_arshift = bit_arshift + -- rol/ror only available on LuaJIT (bit library), not Lua 5.2 (bit32 library) + if bit_lib.rol then + _compat.raw_rol = bit_lib.rol + _compat.raw_ror = bit_lib.ror + else + _compat.raw_rol = nil + _compat.raw_ror = nil + end + return _compat end @@ -320,6 +366,18 @@ function _compat.arshift(a, n) return r end +-- Raw operations for pure Lua fallback are identical to wrapped versions +-- since there's no native library to bypass. +_compat.raw_band = _compat.band +_compat.raw_bor = _compat.bor +_compat.raw_bxor = _compat.bxor +_compat.raw_bnot = _compat.bnot +_compat.raw_lshift = _compat.lshift +_compat.raw_rshift = _compat.rshift +_compat.raw_arshift = _compat.arshift +_compat.raw_rol = nil +_compat.raw_ror = nil + return _compat end end @@ -339,18 +397,17 @@ local _compat = require("bitn._compat") -- Cache methods as locals for faster access local compat_band = _compat.band +local compat_bnot = _compat.bnot local compat_bor = _compat.bor local compat_bxor = _compat.bxor -local compat_bnot = _compat.bnot local compat_lshift = _compat.lshift local compat_rshift = _compat.rshift local impl_name = _compat.impl_name +local math_floor = math.floor -- 16-bit mask constant local MASK16 = 0xFFFF -local math_floor = math.floor - -------------------------------------------------------------------------------- -- Core operations -------------------------------------------------------------------------------- @@ -838,24 +895,42 @@ local bit32 = {} local _compat = require("bitn._compat") -- Cache methods as locals for faster access +local compat_arshift = _compat.arshift local compat_band = _compat.band +local compat_bnot = _compat.bnot local compat_bor = _compat.bor local compat_bxor = _compat.bxor -local compat_bnot = _compat.bnot local compat_lshift = 
_compat.lshift +local compat_raw_arshift = _compat.raw_arshift +local compat_raw_band = _compat.raw_band +local compat_raw_bnot = _compat.raw_bnot +local compat_raw_bor = _compat.raw_bor +local compat_raw_bxor = _compat.raw_bxor +local compat_raw_lshift = _compat.raw_lshift +local compat_raw_rol = _compat.raw_rol +local compat_raw_ror = _compat.raw_ror +local compat_raw_rshift = _compat.raw_rshift local compat_rshift = _compat.rshift -local compat_arshift = _compat.arshift +local compat_to_unsigned = _compat.to_unsigned local impl_name = _compat.impl_name +local math_floor = math.floor -- 32-bit mask constant local MASK32 = 0xFFFFFFFF -local math_floor = math.floor - -------------------------------------------------------------------------------- -- Core operations -------------------------------------------------------------------------------- +--- Convert signed 32-bit value to unsigned. +--- On LuaJIT, bit operations return signed 32-bit integers. This function +--- converts them to unsigned by adding 2^32 to negative values. +--- @param n number Potentially signed 32-bit value +--- @return integer result Unsigned 32-bit value (0 to 0xFFFFFFFF) +function bit32.to_unsigned(n) + return compat_to_unsigned(n) +end + --- Ensure value fits in 32-bit unsigned integer. --- @param n number Input value --- @return integer result 32-bit unsigned integer (0 to 0xFFFFFFFF) @@ -955,6 +1030,68 @@ function bit32.add(a, b) return compat_band(compat_band(a, MASK32) + compat_band(b, MASK32), MASK32) end +-------------------------------------------------------------------------------- +-- Raw (zero-overhead) operations +-------------------------------------------------------------------------------- +-- These functions provide direct access to the underlying bit library without +-- unsigned conversion. On LuaJIT, results may be negative when the high bit +-- is set. The bit pattern is identical to the regular function. 
+-- Use for performance-critical code where sign interpretation doesn't matter. + +--- Raw bitwise AND (may return signed on LuaJIT). +--- @type fun(a: integer, b: integer): integer +--- @see bit32.band For guaranteed unsigned results +bit32.raw_band = compat_raw_band + +--- Raw bitwise OR (may return signed on LuaJIT). +--- @type fun(a: integer, b: integer): integer +--- @see bit32.bor For guaranteed unsigned results +bit32.raw_bor = compat_raw_bor + +--- Raw bitwise XOR (may return signed on LuaJIT). +--- @type fun(a: integer, b: integer): integer +--- @see bit32.bxor For guaranteed unsigned results +bit32.raw_bxor = compat_raw_bxor + +--- Raw bitwise NOT (may return signed on LuaJIT). +--- @type fun(a: integer): integer +--- @see bit32.bnot For guaranteed unsigned results +bit32.raw_bnot = compat_raw_bnot + +--- Raw left shift (may return signed on LuaJIT). +--- @type fun(a: integer, n: integer): integer +--- @see bit32.lshift For guaranteed unsigned results +bit32.raw_lshift = compat_raw_lshift + +--- Raw logical right shift (may return signed on LuaJIT). +--- @type fun(a: integer, n: integer): integer +--- @see bit32.rshift For guaranteed unsigned results +bit32.raw_rshift = compat_raw_rshift + +--- Raw arithmetic right shift (may return signed on LuaJIT). +--- @type fun(a: integer, n: integer): integer +--- @see bit32.arshift For guaranteed unsigned results +bit32.raw_arshift = compat_raw_arshift + +--- Raw left rotate (uses native bit.rol on LuaJIT, falls back to computed otherwise). +--- @type fun(x: integer, n: integer): integer +--- @see bit32.rol For guaranteed unsigned results +bit32.raw_rol = compat_raw_rol or bit32.rol + +--- Raw right rotate (uses native bit.ror on LuaJIT, falls back to computed otherwise). +--- @type fun(x: integer, n: integer): integer +--- @see bit32.ror For guaranteed unsigned results +bit32.raw_ror = compat_raw_ror or bit32.ror + +--- Raw 32-bit addition with overflow handling. 
+--- @param a integer First operand (32-bit) +--- @param b integer Second operand (32-bit) +--- @return integer result Result of (a + b) mod 2^32 (signed on LuaJIT, unsigned elsewhere) +--- @see bit32.add For guaranteed unsigned results +function bit32.raw_add(a, b) + return compat_raw_band(a + b, MASK32) +end + -------------------------------------------------------------------------------- -- Byte conversion functions -------------------------------------------------------------------------------- @@ -967,7 +1104,12 @@ local string_byte = string.byte --- @return string bytes 4-byte string in big-endian order function bit32.u32_to_be_bytes(n) n = compat_band(n, MASK32) - return string_char(math_floor(n / 16777216) % 256, math_floor(n / 65536) % 256, math_floor(n / 256) % 256, n % 256) + return string_char( + math_floor(n / 16777216) % 256, + math_floor(n / 65536) % 256, + math_floor(n / 256) % 256, + math_floor(n % 256) + ) end --- Convert 32-bit unsigned integer to 4 bytes (little-endian). @@ -975,7 +1117,12 @@ end --- @return string bytes 4-byte string in little-endian order function bit32.u32_to_le_bytes(n) n = compat_band(n, MASK32) - return string_char(n % 256, math_floor(n / 256) % 256, math_floor(n / 65536) % 256, math_floor(n / 16777216) % 256) + return string_char( + math_floor(n % 256), + math_floor(n / 256) % 256, + math_floor(n / 65536) % 256, + math_floor(n / 16777216) % 256 + ) end --- Convert 4 bytes to 32-bit unsigned integer (big-endian). 
@@ -1025,6 +1172,14 @@ function bit32.selftest() { name = "mask(-1)", fn = bit32.mask, inputs = { -1 }, expected = 0xFFFFFFFF }, { name = "mask(-256)", fn = bit32.mask, inputs = { -256 }, expected = 0xFFFFFF00 }, + -- to_unsigned tests + { name = "to_unsigned(0)", fn = bit32.to_unsigned, inputs = { 0 }, expected = 0 }, + { name = "to_unsigned(1)", fn = bit32.to_unsigned, inputs = { 1 }, expected = 1 }, + { name = "to_unsigned(0x7FFFFFFF)", fn = bit32.to_unsigned, inputs = { 0x7FFFFFFF }, expected = 0x7FFFFFFF }, + { name = "to_unsigned(-1)", fn = bit32.to_unsigned, inputs = { -1 }, expected = 0xFFFFFFFF }, + { name = "to_unsigned(-2147483648)", fn = bit32.to_unsigned, inputs = { -2147483648 }, expected = 0x80000000 }, + { name = "to_unsigned(-2147483647)", fn = bit32.to_unsigned, inputs = { -2147483647 }, expected = 0x80000001 }, + -- band tests { name = "band(0xFF00FF00, 0x00FF00FF)", fn = bit32.band, inputs = { 0xFF00FF00, 0x00FF00FF }, expected = 0 }, { @@ -1138,25 +1293,25 @@ function bit32.selftest() name = "u32_to_be_bytes(0)", fn = bit32.u32_to_be_bytes, inputs = { 0 }, - expected = string.char(0x00, 0x00, 0x00, 0x00), + expected = string_char(0x00, 0x00, 0x00, 0x00), }, { name = "u32_to_be_bytes(1)", fn = bit32.u32_to_be_bytes, inputs = { 1 }, - expected = string.char(0x00, 0x00, 0x00, 0x01), + expected = string_char(0x00, 0x00, 0x00, 0x01), }, { name = "u32_to_be_bytes(0x12345678)", fn = bit32.u32_to_be_bytes, inputs = { 0x12345678 }, - expected = string.char(0x12, 0x34, 0x56, 0x78), + expected = string_char(0x12, 0x34, 0x56, 0x78), }, { name = "u32_to_be_bytes(0xFFFFFFFF)", fn = bit32.u32_to_be_bytes, inputs = { 0xFFFFFFFF }, - expected = string.char(0xFF, 0xFF, 0xFF, 0xFF), + expected = string_char(0xFF, 0xFF, 0xFF, 0xFF), }, -- u32_to_le_bytes tests @@ -1164,50 +1319,50 @@ function bit32.selftest() name = "u32_to_le_bytes(0)", fn = bit32.u32_to_le_bytes, inputs = { 0 }, - expected = string.char(0x00, 0x00, 0x00, 0x00), + expected = string_char(0x00, 
0x00, 0x00, 0x00), }, { name = "u32_to_le_bytes(1)", fn = bit32.u32_to_le_bytes, inputs = { 1 }, - expected = string.char(0x01, 0x00, 0x00, 0x00), + expected = string_char(0x01, 0x00, 0x00, 0x00), }, { name = "u32_to_le_bytes(0x12345678)", fn = bit32.u32_to_le_bytes, inputs = { 0x12345678 }, - expected = string.char(0x78, 0x56, 0x34, 0x12), + expected = string_char(0x78, 0x56, 0x34, 0x12), }, { name = "u32_to_le_bytes(0xFFFFFFFF)", fn = bit32.u32_to_le_bytes, inputs = { 0xFFFFFFFF }, - expected = string.char(0xFF, 0xFF, 0xFF, 0xFF), + expected = string_char(0xFF, 0xFF, 0xFF, 0xFF), }, -- be_bytes_to_u32 tests { name = "be_bytes_to_u32(0x00000000)", fn = bit32.be_bytes_to_u32, - inputs = { string.char(0x00, 0x00, 0x00, 0x00) }, + inputs = { string_char(0x00, 0x00, 0x00, 0x00) }, expected = 0, }, { name = "be_bytes_to_u32(0x00000001)", fn = bit32.be_bytes_to_u32, - inputs = { string.char(0x00, 0x00, 0x00, 0x01) }, + inputs = { string_char(0x00, 0x00, 0x00, 0x01) }, expected = 1, }, { name = "be_bytes_to_u32(0x12345678)", fn = bit32.be_bytes_to_u32, - inputs = { string.char(0x12, 0x34, 0x56, 0x78) }, + inputs = { string_char(0x12, 0x34, 0x56, 0x78) }, expected = 0x12345678, }, { name = "be_bytes_to_u32(0xFFFFFFFF)", fn = bit32.be_bytes_to_u32, - inputs = { string.char(0xFF, 0xFF, 0xFF, 0xFF) }, + inputs = { string_char(0xFF, 0xFF, 0xFF, 0xFF) }, expected = 0xFFFFFFFF, }, @@ -1215,25 +1370,25 @@ function bit32.selftest() { name = "le_bytes_to_u32(0x00000000)", fn = bit32.le_bytes_to_u32, - inputs = { string.char(0x00, 0x00, 0x00, 0x00) }, + inputs = { string_char(0x00, 0x00, 0x00, 0x00) }, expected = 0, }, { name = "le_bytes_to_u32(0x00000001)", fn = bit32.le_bytes_to_u32, - inputs = { string.char(0x01, 0x00, 0x00, 0x00) }, + inputs = { string_char(0x01, 0x00, 0x00, 0x00) }, expected = 1, }, { name = "le_bytes_to_u32(0x12345678)", fn = bit32.le_bytes_to_u32, - inputs = { string.char(0x78, 0x56, 0x34, 0x12) }, + inputs = { string_char(0x78, 0x56, 0x34, 0x12) }, expected 
= 0x12345678, }, { name = "le_bytes_to_u32(0xFFFFFFFF)", fn = bit32.le_bytes_to_u32, - inputs = { string.char(0xFF, 0xFF, 0xFF, 0xFF) }, + inputs = { string_char(0xFF, 0xFF, 0xFF, 0xFF) }, expected = 0xFFFFFFFF, }, } @@ -1249,10 +1404,10 @@ function bit32.selftest() if type(test.expected) == "string" then local exp_hex, got_hex = "", "" for i = 1, #test.expected do - exp_hex = exp_hex .. string.format("%02X", string.byte(test.expected, i)) + exp_hex = exp_hex .. string.format("%02X", string_byte(test.expected, i)) end for i = 1, #result do - got_hex = got_hex .. string.format("%02X", string.byte(result, i)) + got_hex = got_hex .. string.format("%02X", string_byte(result, i)) end print(" Expected: " .. exp_hex) print(" Got: " .. got_hex) @@ -1263,6 +1418,191 @@ function bit32.selftest() end end + -- Test raw_* operations + print("\n Testing raw_* operations...") + + local raw_tests = { + -- Core bitwise (test high-bit cases where sign matters) + { + name = "raw_band(0xFFFFFFFF, 0x80000000)", + fn = function() + return bit32.to_unsigned(bit32.raw_band(0xFFFFFFFF, 0x80000000)) + end, + expected = bit32.band(0xFFFFFFFF, 0x80000000), + }, + { + name = "raw_bor(0x80000000, 0x00000001)", + fn = function() + return bit32.to_unsigned(bit32.raw_bor(0x80000000, 0x00000001)) + end, + expected = bit32.bor(0x80000000, 0x00000001), + }, + { + name = "raw_bxor(0xAAAAAAAA, 0x55555555)", + fn = function() + return bit32.to_unsigned(bit32.raw_bxor(0xAAAAAAAA, 0x55555555)) + end, + expected = bit32.bxor(0xAAAAAAAA, 0x55555555), + }, + { + name = "raw_bnot(0)", + fn = function() + return bit32.to_unsigned(bit32.raw_bnot(0)) + end, + expected = bit32.bnot(0), + }, + { + name = "raw_bnot(0x80000000)", + fn = function() + return bit32.to_unsigned(bit32.raw_bnot(0x80000000)) + end, + expected = bit32.bnot(0x80000000), + }, + + -- Shifts + { + name = "raw_lshift(1, 31)", + fn = function() + return bit32.to_unsigned(bit32.raw_lshift(1, 31)) + end, + expected = bit32.lshift(1, 31), + }, + { + 
name = "raw_rshift(0x80000000, 1)", + fn = function() + return bit32.to_unsigned(bit32.raw_rshift(0x80000000, 1)) + end, + expected = bit32.rshift(0x80000000, 1), + }, + { + name = "raw_arshift(0x80000000, 1)", + fn = function() + return bit32.to_unsigned(bit32.raw_arshift(0x80000000, 1)) + end, + expected = bit32.arshift(0x80000000, 1), + }, + + -- Shift masking (ensure 32-bit semantics on all platforms) + -- Note: n >= 32 behavior is platform-specific for raw shifts; callers should use n in 0-31 + { + name = "raw_lshift(0x12345678, 16) masks to 32 bits", + fn = function() + return bit32.to_unsigned(bit32.raw_lshift(0x12345678, 16)) + end, + expected = 0x56780000, + }, + { + name = "raw_rshift(0xFFFFFFFF, 16) masks to 32 bits", + fn = function() + return bit32.to_unsigned(bit32.raw_rshift(0xFFFFFFFF, 16)) + end, + expected = 0x0000FFFF, + }, + + -- Addition overflow + { + name = "raw_add(0xFFFFFFFF, 1)", + fn = function() + return bit32.to_unsigned(bit32.raw_add(0xFFFFFFFF, 1)) + end, + expected = bit32.add(0xFFFFFFFF, 1), + }, + { + name = "raw_add(0x80000000, 0x80000000)", + fn = function() + return bit32.to_unsigned(bit32.raw_add(0x80000000, 0x80000000)) + end, + expected = bit32.add(0x80000000, 0x80000000), + }, + } + + for _, test in ipairs(raw_tests) do + total = total + 1 + local result = test.fn() + if result == test.expected then + print(" PASS: " .. test.name) + passed = passed + 1 + else + print(" FAIL: " .. 
test.name) + print(string.format(" Expected: 0x%08X", test.expected)) + print(string.format(" Got: 0x%08X", result)) + end + end + + -- Test raw_rol/raw_ror (always available - falls back to computed if no native) + print("\n Testing raw_rol/raw_ror...") + local rol_ror_tests = { + { + name = "raw_rol(0x80000000, 1)", + fn = function() + return bit32.to_unsigned(bit32.raw_rol(0x80000000, 1)) + end, + expected = bit32.rol(0x80000000, 1), + }, + { + name = "raw_rol(0x12345678, 8)", + fn = function() + return bit32.to_unsigned(bit32.raw_rol(0x12345678, 8)) + end, + expected = bit32.rol(0x12345678, 8), + }, + { + name = "raw_ror(1, 1)", + fn = function() + return bit32.to_unsigned(bit32.raw_ror(1, 1)) + end, + expected = bit32.ror(1, 1), + }, + { + name = "raw_ror(0x12345678, 8)", + fn = function() + return bit32.to_unsigned(bit32.raw_ror(0x12345678, 8)) + end, + expected = bit32.ror(0x12345678, 8), + }, + } + + for _, test in ipairs(rol_ror_tests) do + total = total + 1 + local result = test.fn() + if result == test.expected then + print(" PASS: " .. test.name) + passed = passed + 1 + else + print(" FAIL: " .. 
test.name) + print(string.format(" Expected: 0x%08X", test.expected)) + print(string.format(" Got: 0x%08X", result)) + end + end + + -- Test zero-overhead on LuaJIT (identity check) + if _compat.is_luajit then + print("\n Testing zero-overhead (LuaJIT function identity)...") + local bit = require("bit") + + local identity_tests = { + { name = "raw_band == bit.band", got = bit32.raw_band, expected = bit.band }, + { name = "raw_bor == bit.bor", got = bit32.raw_bor, expected = bit.bor }, + { name = "raw_bxor == bit.bxor", got = bit32.raw_bxor, expected = bit.bxor }, + { name = "raw_bnot == bit.bnot", got = bit32.raw_bnot, expected = bit.bnot }, + { name = "raw_lshift == bit.lshift", got = bit32.raw_lshift, expected = bit.lshift }, + { name = "raw_rshift == bit.rshift", got = bit32.raw_rshift, expected = bit.rshift }, + { name = "raw_arshift == bit.arshift", got = bit32.raw_arshift, expected = bit.arshift }, + { name = "raw_rol == bit.rol", got = bit32.raw_rol, expected = bit.rol }, + { name = "raw_ror == bit.ror", got = bit32.raw_ror, expected = bit.ror }, + } + + for _, test in ipairs(identity_tests) do + total = total + 1 + if rawequal(test.got, test.expected) then + print(" PASS: " .. test.name) + passed = passed + 1 + else + print(" FAIL: " .. test.name .. 
" (not identical function reference)") + end + end + end + print(string.format("\n32-bit operations: %d/%d tests passed\n", passed, total)) return passed == total end @@ -1373,20 +1713,27 @@ local bit64 = {} local bit32 = require("bitn.bit32") local _compat = require("bitn._compat") -local impl_name = _compat.impl_name --- Cache bit32 methods as locals for faster access +-- Cache methods as locals for faster access +local bit32_arshift = bit32.arshift local bit32_band = bit32.band +local bit32_be_bytes_to_u32 = bit32.be_bytes_to_u32 +local bit32_bnot = bit32.bnot local bit32_bor = bit32.bor local bit32_bxor = bit32.bxor -local bit32_bnot = bit32.bnot +local bit32_le_bytes_to_u32 = bit32.le_bytes_to_u32 local bit32_lshift = bit32.lshift +local bit32_raw_arshift = bit32.raw_arshift +local bit32_raw_band = bit32.raw_band +local bit32_raw_bnot = bit32.raw_bnot +local bit32_raw_bor = bit32.raw_bor +local bit32_raw_bxor = bit32.raw_bxor +local bit32_raw_lshift = bit32.raw_lshift +local bit32_raw_rshift = bit32.raw_rshift local bit32_rshift = bit32.rshift -local bit32_arshift = bit32.arshift local bit32_u32_to_be_bytes = bit32.u32_to_be_bytes local bit32_u32_to_le_bytes = bit32.u32_to_le_bytes -local bit32_be_bytes_to_u32 = bit32.be_bytes_to_u32 -local bit32_le_bytes_to_u32 = bit32.le_bytes_to_u32 +local impl_name = _compat.impl_name -- Private metatable for Int64 type identification local Int64Meta = { __name = "Int64" } @@ -1399,11 +1746,21 @@ local Int64Meta = { __name = "Int64" } -------------------------------------------------------------------------------- --- Create a new Int64 value with metatable marker. +--- Normalizes signed 32-bit values to unsigned (for LuaJIT raw_* compatibility). --- @param high? integer Upper 32 bits (default: 0) --- @param low? 
integer Lower 32 bits (default: 0)
 --- @return Int64HighLow value Int64 value with metatable marker
 function bit64.new(high, low)
-  return setmetatable({ high or 0, low or 0 }, Int64Meta)
+  high = high or 0
+  low = low or 0
+  -- Normalize signed to unsigned by adding 2^32 to a negative half (handles LuaJIT raw_* results)
+  if high < 0 then
+    high = high + 0x100000000
+  end
+  if low < 0 then
+    low = low + 0x100000000
+  end
+  return setmetatable({ high, low }, Int64Meta)
 end

 --- Check if a value is an Int64 (created by bit64 functions).
@@ -1737,6 +2094,178 @@ bit64.asr = bit64.arshift
 --- Alias for is_int64 (compatibility with older API).
 bit64.isInt64 = bit64.is_int64

+--------------------------------------------------------------------------------
+-- Raw (zero-overhead) operations
+--------------------------------------------------------------------------------
+-- These functions use bit32.raw_* internally for performance-critical code.
+-- On LuaJIT, bit32.raw_* may return signed 32-bit values with the correct bit pattern; bit64.new() maps them back to unsigned.
+-- Use for crypto code and tight loops where sign interpretation doesn't matter.
+
+--- Raw bitwise AND (uses bit32.raw_band internally).
+--- @param a Int64HighLow First operand {high, low}
+--- @param b Int64HighLow Second operand {high, low}
+--- @return Int64HighLow result {high, low} AND result
+function bit64.raw_band(a, b)
+  return bit64.new(bit32_raw_band(a[1], b[1]), bit32_raw_band(a[2], b[2]))
+end
+
+--- Raw bitwise OR (uses bit32.raw_bor internally).
+--- @param a Int64HighLow First operand {high, low}
+--- @param b Int64HighLow Second operand {high, low}
+--- @return Int64HighLow result {high, low} OR result
+function bit64.raw_bor(a, b)
+  return bit64.new(bit32_raw_bor(a[1], b[1]), bit32_raw_bor(a[2], b[2]))
+end
+
+--- Raw bitwise XOR (uses bit32.raw_bxor internally).
+--- @param a Int64HighLow First operand {high, low}
+--- @param b Int64HighLow Second operand {high, low}
+--- @return Int64HighLow result {high, low} XOR result
+function bit64.raw_bxor(a, b)
+  return bit64.new(bit32_raw_bxor(a[1], b[1]), bit32_raw_bxor(a[2], b[2]))
+end
+
+--- Raw bitwise NOT (uses bit32.raw_bnot internally).
+--- @param a Int64HighLow Operand {high, low}
+--- @return Int64HighLow result {high, low} NOT result
+function bit64.raw_bnot(a)
+  return bit64.new(bit32_raw_bnot(a[1]), bit32_raw_bnot(a[2]))
+end
+
+--- Raw left shift (uses bit32.raw_* internally).
+--- @param x Int64HighLow Value to shift {high, low}
+--- @param n integer Number of positions to shift (must be >= 0; n >= 64 yields {0, 0})
+--- @return Int64HighLow result {high, low} shifted value
+function bit64.raw_lshift(x, n)
+  if n == 0 then
+    return bit64.new(x[1], x[2])
+  elseif n >= 64 then
+    return bit64.new(0, 0)
+  elseif n >= 32 then
+    return bit64.new(bit32_raw_lshift(x[2], n - 32), 0)
+  else
+    local new_high = bit32_raw_bor(bit32_raw_lshift(x[1], n), bit32_raw_rshift(x[2], 32 - n))
+    local new_low = bit32_raw_lshift(x[2], n)
+    return bit64.new(new_high, new_low)
+  end
+end
+
+--- Raw logical right shift (uses bit32.raw_* internally).
+--- @param x Int64HighLow Value to shift {high, low}
+--- @param n integer Number of positions to shift (must be >= 0; n >= 64 yields {0, 0})
+--- @return Int64HighLow result {high, low} shifted value
+function bit64.raw_rshift(x, n)
+  if n == 0 then
+    return bit64.new(x[1], x[2])
+  elseif n >= 64 then
+    return bit64.new(0, 0)
+  elseif n >= 32 then
+    return bit64.new(0, bit32_raw_rshift(x[1], n - 32))
+  else
+    local new_low = bit32_raw_bor(bit32_raw_rshift(x[2], n), bit32_raw_lshift(x[1], 32 - n))
+    local new_high = bit32_raw_rshift(x[1], n)
+    return bit64.new(new_high, new_low)
+  end
+end
+
+--- Raw arithmetic right shift (uses bit32.raw_* internally).
+--- @param x Int64HighLow Value to shift {high, low}
+--- @param n integer Number of positions to shift (must be >= 0; n >= 64 fills every bit with the sign bit)
+--- @return Int64HighLow result {high, low} shifted value
+function bit64.raw_arshift(x, n)
+  if n == 0 then
+    return bit64.new(x[1], x[2])
+  end
+
+  local is_negative = bit32_raw_band(x[1], 0x80000000) ~= 0
+
+  if n >= 64 then
+    if is_negative then
+      return bit64.new(0xFFFFFFFF, 0xFFFFFFFF)
+    else
+      return bit64.new(0, 0)
+    end
+  elseif n >= 32 then
+    local new_low = bit32_raw_arshift(x[1], n - 32)
+    local new_high = is_negative and 0xFFFFFFFF or 0
+    return bit64.new(new_high, new_low)
+  else
+    local new_low = bit32_raw_bor(bit32_raw_rshift(x[2], n), bit32_raw_lshift(x[1], 32 - n))
+    local new_high = bit32_raw_arshift(x[1], n)
+    return bit64.new(new_high, new_low)
+  end
+end
+
+--- Raw left rotate (uses bit32.raw_* internally).
+--- @param x Int64HighLow Value to rotate {high, low}
+--- @param n integer Number of positions to rotate (reduced modulo 64)
+--- @return Int64HighLow result {high, low} rotated value
+function bit64.raw_rol(x, n)
+  n = n % 64
+  if n == 0 then
+    return bit64.new(x[1], x[2])
+  end
+
+  local high, low = x[1], x[2]
+
+  if n == 32 then
+    return bit64.new(low, high)
+  elseif n < 32 then
+    local new_high = bit32_raw_bor(bit32_raw_lshift(high, n), bit32_raw_rshift(low, 32 - n))
+    local new_low = bit32_raw_bor(bit32_raw_lshift(low, n), bit32_raw_rshift(high, 32 - n))
+    return bit64.new(new_high, new_low)
+  else
+    n = n - 32
+    local new_high = bit32_raw_bor(bit32_raw_lshift(low, n), bit32_raw_rshift(high, 32 - n))
+    local new_low = bit32_raw_bor(bit32_raw_lshift(high, n), bit32_raw_rshift(low, 32 - n))
+    return bit64.new(new_high, new_low)
+  end
+end
+
+--- Raw right rotate (uses bit32.raw_* internally).
+--- @param x Int64HighLow Value to rotate {high, low}
+--- @param n integer Number of positions to rotate (reduced modulo 64)
+--- @return Int64HighLow result {high, low} rotated value
+function bit64.raw_ror(x, n)
+  n = n % 64
+  if n == 0 then
+    return bit64.new(x[1], x[2])
+  end
+
+  local high, low = x[1], x[2]
+
+  if n == 32 then
+    return bit64.new(low, high)
+  elseif n < 32 then
+    local new_low = bit32_raw_bor(bit32_raw_rshift(low, n), bit32_raw_lshift(high, 32 - n))
+    local new_high = bit32_raw_bor(bit32_raw_rshift(high, n), bit32_raw_lshift(low, 32 - n))
+    return bit64.new(new_high, new_low)
+  else
+    n = n - 32
+    local new_low = bit32_raw_bor(bit32_raw_rshift(high, n), bit32_raw_lshift(low, 32 - n))
+    local new_high = bit32_raw_bor(bit32_raw_rshift(low, n), bit32_raw_lshift(high, 32 - n))
+    return bit64.new(new_high, new_low)
+  end
+end
+
+--- Raw 64-bit addition modulo 2^64 (halves are reduced with % 2^32, so signed LuaJIT raw_* halves are accepted).
+--- @param a Int64HighLow First operand {high, low}
+--- @param b Int64HighLow Second operand {high, low}
+--- @return Int64HighLow result {high, low} sum
+function bit64.raw_add(a, b)
+  -- Reduce each low word into [0, 2^32) first: with a signed low word the carry test below would miss the carry.
+  local low = a[2] % 0x100000000 + b[2] % 0x100000000
+  local high = a[1] + b[1]
+
+  if low >= 0x100000000 then
+    high = high + 1
+    low = low % 0x100000000
+  end
+
+  -- Lua's floored modulo maps a negative (signed) high sum into [0, 2^32) as well.
+  return bit64.new(high % 0x100000000, low)
+end
+
 --------------------------------------------------------------------------------
 -- Self-test
 --------------------------------------------------------------------------------
@@ -2185,6 +2714,18 @@ function bit64.selftest()
     print(" FAIL: new() with no args creates {0, 0}")
   end

+  -- Test bit64.new() normalizes negative values (LuaJIT raw_* compatibility)
+  total = total + 1
+  local neg_val = bit64.new(-1, -2147483648) -- -1 -> 0xFFFFFFFF, -2147483648 -> 0x80000000
+  if bit64.is_int64(neg_val) and neg_val[1] == 0xFFFFFFFF and neg_val[2] == 0x80000000 then
+    print(" PASS: new() normalizes negative values to unsigned")
+    passed = passed + 1
+  else
+    print(" FAIL: new() normalizes negative values to
unsigned") + print(string.format(" Expected: {0x%08X, 0x%08X}", 0xFFFFFFFF, 0x80000000)) + print(string.format(" Got: {0x%08X, 0x%08X}", neg_val[1], neg_val[2])) + end + -- Test is_int64() returns false for regular tables total = total + 1 local plain_table = { 0x12345678, 0x9ABCDEF0 } @@ -2209,61 +2750,61 @@ function bit64.selftest() { name = "band", fn = function() - return bit64.band({ 1, 2 }, { 3, 4 }) + return bit64.band(bit64.new(1, 2), bit64.new(3, 4)) end, }, { name = "bor", fn = function() - return bit64.bor({ 1, 2 }, { 3, 4 }) + return bit64.bor(bit64.new(1, 2), bit64.new(3, 4)) end, }, { name = "bxor", fn = function() - return bit64.bxor({ 1, 2 }, { 3, 4 }) + return bit64.bxor(bit64.new(1, 2), bit64.new(3, 4)) end, }, { name = "bnot", fn = function() - return bit64.bnot({ 1, 2 }) + return bit64.bnot(bit64.new(1, 2)) end, }, { name = "lshift", fn = function() - return bit64.lshift({ 1, 2 }, 1) + return bit64.lshift(bit64.new(1, 2), 1) end, }, { name = "rshift", fn = function() - return bit64.rshift({ 1, 2 }, 1) + return bit64.rshift(bit64.new(1, 2), 1) end, }, { name = "arshift", fn = function() - return bit64.arshift({ 1, 2 }, 1) + return bit64.arshift(bit64.new(1, 2), 1) end, }, { name = "rol", fn = function() - return bit64.rol({ 1, 2 }, 1) + return bit64.rol(bit64.new(1, 2), 1) end, }, { name = "ror", fn = function() - return bit64.ror({ 1, 2 }, 1) + return bit64.ror(bit64.new(1, 2), 1) end, }, { name = "add", fn = function() - return bit64.add({ 1, 2 }, { 3, 4 }) + return bit64.add(bit64.new(1, 2), bit64.new(3, 4)) end, }, { @@ -2292,7 +2833,7 @@ function bit64.selftest() end -- Test to_number strict mode error case - print("\nRunning to_number strict mode tests...") + print("\nRunning to_number/from_number edge case tests...") total = total + 1 local ok, err = pcall(function() bit64.to_number(bit64.new(0x00200000, 0x00000000), true) -- 2^53, exceeds 53-bit @@ -2309,6 +2850,220 @@ function bit64.selftest() end end + -- Test to_number pass-through for 
number input + total = total + 1 + local num_input = 12345 + local num_result = bit64.to_number(num_input) + if num_result == num_input then + print(" PASS: to_number passes through number input unchanged") + passed = passed + 1 + else + print(" FAIL: to_number passes through number input unchanged") + print(" Expected: " .. tostring(num_input)) + print(" Got: " .. tostring(num_result)) + end + + -- Test to_number errors on plain table (non-Int64) + total = total + 1 + ok, err = pcall(function() + bit64.to_number({ 1, 2 }) -- plain table, not Int64 + end) + if not ok and type(err) == "string" and string.find(err, "not a valid Int64") then + print(" PASS: to_number errors on plain table (non-Int64)") + passed = passed + 1 + else + print(" FAIL: to_number errors on plain table (non-Int64)") + if ok then + print(" Expected error but got success") + else + print(" Expected 'not a valid Int64' error but got: " .. tostring(err)) + end + end + + -- Test from_number pass-through for Int64 input + total = total + 1 + local int64_input = bit64.new(0x12345678, 0x9ABCDEF0) + local int64_result = bit64.from_number(int64_input) + if rawequal(int64_result, int64_input) then + print(" PASS: from_number passes through Int64 input unchanged") + passed = passed + 1 + else + print(" FAIL: from_number passes through Int64 input unchanged") + print(" Expected same reference, got different object") + end + + -- Test raw_* operations + print("\n Testing raw_* operations...") + + local raw_tests = { + -- Core bitwise (test high-bit cases where sign matters) + { + name = "raw_band(new(0xFFFFFFFF, 0x80000000), new(0x80000000, 0xFFFFFFFF))", + fn = function() + return bit64.raw_band(bit64.new(0xFFFFFFFF, 0x80000000), bit64.new(0x80000000, 0xFFFFFFFF)) + end, + expected = bit64.new(0x80000000, 0x80000000), + }, + { + name = "raw_bor(new(0x80000000, 0), new(0, 0x80000000))", + fn = function() + return bit64.raw_bor(bit64.new(0x80000000, 0), bit64.new(0, 0x80000000)) + end, + expected = 
bit64.new(0x80000000, 0x80000000), + }, + { + name = "raw_bxor(new(0xAAAAAAAA, 0x55555555), new(0x55555555, 0xAAAAAAAA))", + fn = function() + return bit64.raw_bxor(bit64.new(0xAAAAAAAA, 0x55555555), bit64.new(0x55555555, 0xAAAAAAAA)) + end, + expected = bit64.new(0xFFFFFFFF, 0xFFFFFFFF), + }, + { + name = "raw_bnot(new(0, 0))", + fn = function() + return bit64.raw_bnot(bit64.new(0, 0)) + end, + expected = bit64.new(0xFFFFFFFF, 0xFFFFFFFF), + }, + + -- Shifts + { + name = "raw_lshift(new(0, 1), 63)", + fn = function() + return bit64.raw_lshift(bit64.new(0, 1), 63) + end, + expected = bit64.new(0x80000000, 0), + }, + { + name = "raw_rshift(new(0x80000000, 0), 63)", + fn = function() + return bit64.raw_rshift(bit64.new(0x80000000, 0), 63) + end, + expected = bit64.new(0, 1), + }, + { + name = "raw_arshift(new(0x80000000, 0), 32)", + fn = function() + return bit64.raw_arshift(bit64.new(0x80000000, 0), 32) + end, + expected = bit64.new(0xFFFFFFFF, 0x80000000), + }, + + -- Rotates + { + name = "raw_rol(new(0x12345678, 0x9ABCDEF0), 16)", + fn = function() + return bit64.raw_rol(bit64.new(0x12345678, 0x9ABCDEF0), 16) + end, + expected = bit64.new(0x56789ABC, 0xDEF01234), + }, + { + name = "raw_ror(new(0x12345678, 0x9ABCDEF0), 16)", + fn = function() + return bit64.raw_ror(bit64.new(0x12345678, 0x9ABCDEF0), 16) + end, + expected = bit64.new(0xDEF01234, 0x56789ABC), + }, + + -- Addition + { + name = "raw_add(new(0xFFFFFFFF, 0xFFFFFFFF), new(0, 1))", + fn = function() + return bit64.raw_add(bit64.new(0xFFFFFFFF, 0xFFFFFFFF), bit64.new(0, 1)) + end, + expected = bit64.new(0, 0), + }, + } + + for _, test in ipairs(raw_tests) do + total = total + 1 + local result = test.fn() + if eq64(result, test.expected) then + print(" PASS: " .. test.name) + passed = passed + 1 + else + print(" FAIL: " .. test.name) + print(" Expected: " .. fmt64(test.expected)) + print(" Got: " .. 
fmt64(result)) + end + end + + -- Test that raw_* operations return Int64 + print("\n Testing raw_* operations return Int64...") + local raw_ops_returning_int64 = { + { + name = "raw_band", + fn = function() + return bit64.raw_band(bit64.new(1, 2), bit64.new(3, 4)) + end, + }, + { + name = "raw_bor", + fn = function() + return bit64.raw_bor(bit64.new(1, 2), bit64.new(3, 4)) + end, + }, + { + name = "raw_bxor", + fn = function() + return bit64.raw_bxor(bit64.new(1, 2), bit64.new(3, 4)) + end, + }, + { + name = "raw_bnot", + fn = function() + return bit64.raw_bnot(bit64.new(1, 2)) + end, + }, + { + name = "raw_lshift", + fn = function() + return bit64.raw_lshift(bit64.new(1, 2), 1) + end, + }, + { + name = "raw_rshift", + fn = function() + return bit64.raw_rshift(bit64.new(1, 2), 1) + end, + }, + { + name = "raw_arshift", + fn = function() + return bit64.raw_arshift(bit64.new(1, 2), 1) + end, + }, + { + name = "raw_rol", + fn = function() + return bit64.raw_rol(bit64.new(1, 2), 1) + end, + }, + { + name = "raw_ror", + fn = function() + return bit64.raw_ror(bit64.new(1, 2), 1) + end, + }, + { + name = "raw_add", + fn = function() + return bit64.raw_add(bit64.new(1, 2), bit64.new(3, 4)) + end, + }, + } + + for _, op in ipairs(raw_ops_returning_int64) do + total = total + 1 + local result = op.fn() + if bit64.is_int64(result) then + print(" PASS: " .. op.name .. "() returns Int64") + passed = passed + 1 + else + print(" FAIL: " .. op.name .. "() returns Int64") + end + end + print(string.format("\n64-bit operations: %d/%d tests passed\n", passed, total)) return passed == total end @@ -2524,7 +3279,7 @@ local bitn = { } --- Library version (injected at build time for releases). -local VERSION = "v0.5.1" +local VERSION = "v0.6.0" --- Get the library version string. --- @return string version Version string (e.g., "v1.0.0" or "dev")