diff --git a/CLAUDE.md b/CLAUDE.md index 6633bd3..affbb76 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -90,20 +90,22 @@ src/noiseprotocol/ ├── crypto/ # Cryptographic primitives │ ├── init.lua # Crypto module aggregator │ ├── x25519.lua / x448.lua # Diffie-Hellman functions -│ ├── chacha20*.lua # Stream cipher and AEAD +│ ├── chacha20.lua # Stream cipher +│ ├── chacha20_poly1305.lua # ChaCha20-Poly1305 AEAD │ ├── aes_gcm.lua # AES-GCM AEAD │ ├── poly1305.lua # Poly1305 MAC -│ ├── sha*.lua / blake2.lua # Hash functions +│ ├── sha256.lua / sha512.lua / blake2.lua # Hash functions ├── utils/ # Utility modules -│ ├── bit32.lua / bit64.lua # Bitwise operations │ ├── bytes.lua # Byte manipulation utilities │ └── benchmark.lua # Performance measurement tools └── openssl_wrapper.lua # Optional OpenSSL acceleration +vendor/ +└── bitn.lua # Unified bitwise operations for all Lua versions ``` ### Key Classes and APIs -**NoiseConnection** (`src/noiseprotocol/init.lua:1551`) +**NoiseConnection** (`src/noiseprotocol/init.lua:1563`) - Main API for establishing secure connections - Handles handshake patterns (XX, IK, NK, etc.) and PSK variants - Manages transport phase encryption/decryption @@ -143,6 +145,7 @@ Supports all standard patterns from the Noise specification: - LuaJIT significantly outperforms standard Lua interpreters - Benchmarks should be run with LuaJIT for realistic performance data - X448 is notably slower than X25519 in pure Lua +- Crypto modules use pre-allocated arrays for performance; not thread-safe for concurrent coroutines ### Compatibility - Supports Lua 5.1, 5.2, 5.3, 5.4, and LuaJIT diff --git a/Makefile b/Makefile index 5dbca7c..57d3d26 100644 --- a/Makefile +++ b/Makefile @@ -1,3 +1,9 @@ +# Luarocks path for amalg and other tools +LUAROCKS_PATH := $(shell luarocks path --lr-path 2>/dev/null) + +# Lua path for local modules (src, vendor) +LUA_PATH_LOCAL := ./?.lua;./?/init.lua;./src/?.lua;./src/?/init.lua;./vendor/?.lua;$(LUAROCKS_PATH) + # Default target .PHONY: all all: format lint test build @@ -37,7 +43,7 @@ build/amalg.cache: src/noiseprotocol/init.lua @echo "Generating amalgamation cache..." @mkdir -p build @if command -v amalg.lua >/dev/null 2>&1; then \ - LUA_PATH="./src/?.lua;./src/?/init.lua;$(LUA_PATH)" lua -lamalg src/noiseprotocol/init.lua && mv amalg.cache build || exit 1; \ + LUA_PATH="$(LUA_PATH_LOCAL)" lua -lamalg src/noiseprotocol/init.lua && mv amalg.cache build || exit 1; \ echo "Generated amalg.cache"; \ else \ echo "Error: amalg not found."; \ @@ -51,9 +57,9 @@ build/amalg.cache: src/noiseprotocol/init.lua build: build/amalg.cache @echo "Building single-file distribution..." @if command -v amalg.lua >/dev/null 2>&1; then \ - LUA_PATH="./src/?.lua;./src/?/init.lua;$(LUA_PATH)" amalg.lua -o build/noiseprotocol.lua -C ./build/amalg.cache || exit 1;\ + LUA_PATH="$(LUA_PATH_LOCAL)" amalg.lua -o build/noiseprotocol.lua -C ./build/amalg.cache || exit 1; \ echo "Built build/noiseprotocol.lua"; \ - LUA_PATH="./src/?.lua;./src/?/init.lua;$(LUA_PATH)" amalg.lua -o build/noiseprotocol-core.lua -C ./build/amalg.cache -i "vendor%." || exit 1;\ + LUA_PATH="$(LUA_PATH_LOCAL)" amalg.lua -o build/noiseprotocol-core.lua -C ./build/amalg.cache -i "bitn" || exit 1; \ echo "Built build/noiseprotocol-core.lua (no vendor dependencies)"; \ VERSION=$$(git describe --exact-match --tags 2>/dev/null || echo "dev"); \ if [ "$$VERSION" != "dev" ]; then \ @@ -120,7 +126,7 @@ format: .PHONY: format-check format-check: @if command -v stylua >/dev/null 2>&1; then \ - echo "Running stylua check..."; \ + echo "Running stylua check..."; \ stylua --check --indent-type Spaces --column-width 120 --line-endings Unix \ --indent-width 2 --quote-style AutoPreferDouble \ src/ tests/; \ @@ -155,24 +161,26 @@ help: @echo "Noise Protocol Framework - Makefile targets" @echo "" @echo "Testing:" - @echo " make test - Run all tests" - @echo " make test- - Run specific test (e.g., make test-x25519)" - @echo " make test-matrix - Run test matrix across Lua versions" + @echo " make test - Run all tests" + @echo " make test- - Run specific test (e.g., make test-x25519)" + @echo " make test-matrix - Run tests across all Lua versions" + @echo " make test-matrix- - Run specific test across all Lua versions" @echo "" @echo "Benchmarking:" - @echo " make bench - Run all benchmarks" - @echo " make bench- - Run specific benchmark (e.g., make bench-x25519)" + @echo " make bench - Run all benchmarks" + @echo " make bench- - Run specific benchmark (e.g., make bench-x25519)" @echo "" @echo "Building:" - @echo " make build - Build single-file distribution" + @echo " make build - Build single-file distributions" @echo "" @echo "Code Quality:" - @echo " make format - Format all code (Lua)" - @echo " make format-check - Check code formatting" - @echo " make lint - Lint code with luacheck" + @echo " make check - Run format-check and lint" + @echo " make format - Format code with stylua" + @echo " make format-check - Check code formatting" + @echo " make lint - Lint code with luacheck" @echo "" @echo "Setup:" - @echo " make install-deps - Install all development dependencies" - @echo " make clean - Remove generated files" + @echo " make install-deps - Install development dependencies" + @echo " make clean - Remove generated files" @echo "" - @echo " make help - Show this help" + @echo " make help - Show this help" diff --git a/README.md b/README.md index 2143057..41416ea 100644 --- a/README.md +++ b/README.md @@ -24,7 +24,7 @@ Download a pre-built single-file module from the - **`noiseprotocol.lua`** - Complete bundle with all dependencies included (zero external dependencies) -- **`noiseprotocol-core.lua`** - Core library only, requires `vendor.bitn` to be +- **`noiseprotocol-core.lua`** - Core library only, requires `bitn` to be installed separately ### Option 2: From Source @@ -92,7 +92,7 @@ print("Handshake complete!") -- Print first 16 bytes of handshake hash as hex local utils = require("noiseprotocol.utils") local hash = alice:get_handshake_hash() -print("Alice handshake hash:", bytes.to_hex(hash):sub(1, 32)) -- 32 hex chars = 16 bytes +print("Alice handshake hash:", utils.bytes.to_hex(hash):sub(1, 32)) -- 32 hex chars = 16 bytes -- Transport phase - send encrypted messages local ciphertext1 = alice:send_message("Hello Bob!") @@ -119,22 +119,57 @@ All one-way and interactive patterns from the Noise specification are supported: - **AEAD**: ChaChaPoly, AESGCM - **Hash**: SHA256, SHA512, BLAKE2s, BLAKE2b -## Testing +## Development -Run the test suite: +### Setup ```bash -# Run all tests with default Lua interpreter -./run_tests.sh +# Install development dependencies (stylua, luacheck, amalg) +make install-deps +``` + +### Testing + +```bash +make test # Run all tests +make test-chacha20 # Run specific module tests +make test-matrix # Run tests across all Lua versions +make test-matrix-x25519 # Run specific module across all Lua versions -# Run with specific Lua version +# Or use scripts directly with custom Lua binary LUA_BINARY=lua5.1 ./run_tests.sh +``` -# Run specific modules -./run_tests.sh chacha20 poly1305 +### Benchmarking -# Run test matrix across all Lua versions -./run_tests_matrix.sh +```bash +make bench # Run all benchmarks +make bench-x25519 # Run specific module benchmark + +# Or use scripts directly with custom Lua binary +LUA_BINARY=luajit ./run_benchmarks.sh +``` + +### Code Quality + +```bash +make check # Run format check and lint +make format # Format code with stylua +make format-check # Check formatting without modifying +make lint # Run luacheck +``` + +### Building + +```bash +make build # Build single-file distributions (build/noiseprotocol.lua, build/noiseprotocol-core.lua) +make clean # Remove generated files +``` + +### Help + +```bash +make help # Show all available targets ``` ## Current Limitations @@ -142,10 +177,8 @@ LUA_BINARY=lua5.1 ./run_tests.sh - Pure Lua performance is slower than native implementations - No constant-time guarantees (not suitable for production use without additional hardening) - -## Future Plans - -- Performance optimizations for the pure Lua implementation +- Not thread-safe for concurrent coroutines (uses pre-allocated arrays for + performance) ## Security Warning @@ -162,7 +195,7 @@ native cryptographic libraries. ## License -GNU Affero General Public License v3.0 - see LICENSE file for details +GNU Affero General Public License v3.0 - see LICENSE file for details. ## Contributing diff --git a/run_benchmarks.sh b/run_benchmarks.sh index af3657b..045d4f7 100755 --- a/run_benchmarks.sh +++ b/run_benchmarks.sh @@ -42,7 +42,7 @@ echo script_dir=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd) # Add repository root to Lua's package path -lua_path="$script_dir/?.lua;$script_dir/?/init.lua;$script_dir/src/?.lua;$script_dir/src/?/init.lua;$LUA_PATH" +lua_path="$script_dir/?.lua;$script_dir/?/init.lua;$script_dir/src/?.lua;$script_dir/src/?/init.lua;$script_dir/vendor/?.lua;$LUA_PATH" # Parse command line arguments to determine which modules to run default_modules=("aes_gcm" "blake2" "chacha20" "chacha20_poly1305" "poly1305" "sha256" "sha512" "x448" "x25519") diff --git a/src/noiseprotocol/crypto/aes_gcm.lua b/src/noiseprotocol/crypto/aes_gcm.lua index 199dccf..51ef88c 100644 --- a/src/noiseprotocol/crypto/aes_gcm.lua +++ b/src/noiseprotocol/crypto/aes_gcm.lua @@ -1,14 +1,27 @@ --- @module "noiseprotocol.crypto.aes_gcm" --- AES-GCM Authenticated Encryption with Associated Data (AEAD) Implementation for portability. +--- @class noiseprotocol.crypto.aes_gcm local aes_gcm = {} -local bit32 = require("vendor.bitn").bit32 +local bit32 = require("bitn").bit32 local openssl_wrapper = require("noiseprotocol.openssl_wrapper") local utils = require("noiseprotocol.utils") local bytes = utils.bytes local benchmark_op = utils.benchmark.benchmark_op +-- Local references for performance (avoid module table lookups in hot loops) +local bit32_band = bit32.band +local bit32_bor = bit32.bor +local bit32_bxor = bit32.bxor +local bit32_lshift = bit32.lshift +local bit32_rshift = bit32.rshift +local string_byte = string.byte +local string_char = string.char +local string_rep = string.rep +local string_sub = string.sub +local table_concat = table.concat + -- ============================================================================ -- AES CORE IMPLEMENTATION -- ============================================================================ @@ -293,16 +306,63 @@ local RCON = { --- @alias AESGCMBlock [integer, integer, integer, integer, integer, integer, integer, integer, integer, integer, integer, integer, integer, integer, integer, integer] --- @alias AESGCMState [AESGCMWord, AESGCMWord, AESGCMWord, AESGCMWord] +--- Initialize a 16-element GCM block with zeros +--- @return AESGCMBlock block Initialized block +local function create_gcm_block() + local arr = {} + for i = 1, 16 do + arr[i] = 0 + end + --- @cast arr AESGCMBlock + return arr +end + +--- Initialize a 4x4 AES state array with zeros +--- @return AESGCMState state Initialized state +local function create_aes_state() + local state = {} + for i = 1, 4 do + state[i] = {} + for j = 1, 4 do + state[i][j] = 0 + end + end + --- @cast state AESGCMState + return state +end + +--- Initialize a 4-element AES word with zeros +--- @return AESGCMWord word Initialized word +local function create_aes_word() + local arr = {} + for i = 1, 4 do + arr[i] = 0 + end + --- @cast arr AESGCMWord + return arr +end + +-- Pre-allocated arrays for gcm_multiply() to avoid repeated allocation +local gcm_z = create_gcm_block() +local gcm_v = create_gcm_block() + +-- Pre-allocated state array for aes_encrypt_block() +local aes_state = create_aes_state() + +-- Pre-allocated arrays for mix_columns() +local mix_a = create_aes_word() +local mix_b = create_aes_word() + --- XOR two 4-byte words --- @param a AESGCMWord 4-byte array --- @param b AESGCMWord 4-byte array --- @return table Word 4-byte array local function xor_words(a, b) return { - bit32.bxor(a[1], b[1]), - bit32.bxor(a[2], b[2]), - bit32.bxor(a[3], b[3]), - bit32.bxor(a[4], b[4]), + bit32_bxor(a[1], b[1]), + bit32_bxor(a[2], b[2]), + bit32_bxor(a[3], b[3]), + bit32_bxor(a[4], b[4]), } end @@ -347,24 +407,25 @@ local function key_expansion(key) end -- Convert key to words - --- @type AESGCMWord + --- @type AESGCMState local w = {} - for i = 0, nk - 1 do + for i = 1, nk do w[i] = { - string.byte(key, i * 4 + 1), - string.byte(key, i * 4 + 2), - string.byte(key, i * 4 + 3), - string.byte(key, i * 4 + 4), + string_byte(key, (i - 1) * 4 + 1), + string_byte(key, (i - 1) * 4 + 2), + string_byte(key, (i - 1) * 4 + 3), + string_byte(key, (i - 1) * 4 + 4), } end -- Expand key - for i = nk, 4 * (nr + 1) - 1 do + for i = nk + 1, 4 * (nr + 1) do local temp = w[i - 1] - if i % nk == 0 then - local t = assert(RCON[i / nk], "Invalid RCON index " .. (i / nk)) + local idx = i - 1 -- 0-based index for modulo arithmetic + if idx % nk == 0 then + local t = assert(RCON[idx / nk], "Invalid RCON index " .. (idx / nk)) temp = xor_words(sub_word(rot_word(temp)), { t, 0, 0, 0 }) - elseif nk > 6 and i % nk == 4 then + elseif nk > 6 and idx % nk == 4 then temp = sub_word(temp) end w[i] = xor_words(w[i - nk], temp) @@ -376,29 +437,28 @@ end --- MixColumns transformation --- @param state AESGCMState 4x4 state matrix local function mix_columns(state) - for c = 0, 3 do - --- @type AESGCMWord - local a = {} - --- @type AESGCMWord - local b = {} - for i = 0, 3 do + -- Reuse pre-allocated arrays + local a = mix_a + local b = mix_b + for c = 1, 4 do + for i = 1, 4 do a[i] = state[i][c] - b[i] = bit32.band(state[i][c], 0x80) ~= 0 and bit32.bxor(bit32.band(bit32.lshift(state[i][c], 1), 0xFF), 0x1B) - or bit32.band(bit32.lshift(state[i][c], 1), 0xFF) + b[i] = bit32_band(state[i][c], 0x80) ~= 0 and bit32_bxor(bit32_band(bit32_lshift(state[i][c], 1), 0xFF), 0x1B) + or bit32_band(bit32_lshift(state[i][c], 1), 0xFF) end - state[0][c] = bit32.bxor(bit32.bxor(bit32.bxor(b[0], a[1]), bit32.bxor(b[1], a[2])), a[3]) - state[1][c] = bit32.bxor(bit32.bxor(bit32.bxor(a[0], b[1]), bit32.bxor(a[2], b[2])), a[3]) - state[2][c] = bit32.bxor(bit32.bxor(bit32.bxor(a[0], a[1]), bit32.bxor(b[2], a[3])), b[3]) - state[3][c] = bit32.bxor(bit32.bxor(bit32.bxor(a[0], b[0]), bit32.bxor(a[1], a[2])), b[3]) + state[1][c] = bit32_bxor(bit32_bxor(bit32_bxor(b[1], a[2]), bit32_bxor(b[2], a[3])), a[4]) + state[2][c] = bit32_bxor(bit32_bxor(bit32_bxor(a[1], b[2]), bit32_bxor(a[3], b[3])), a[4]) + state[3][c] = bit32_bxor(bit32_bxor(bit32_bxor(a[1], a[2]), bit32_bxor(b[3], a[4])), b[4]) + state[4][c] = bit32_bxor(bit32_bxor(bit32_bxor(a[1], b[1]), bit32_bxor(a[2], a[3])), b[4]) end end --- SubBytes transformation --- @param state AESGCMState 4x4 state matrix local function sub_bytes(state) - for i = 0, 3 do - for j = 0, 3 do + for i = 1, 4 do + for j = 1, 4 do local s_index = state[i][j] + 1 state[i][j] = assert(SBOX[s_index], "Invalid SBOX index " .. s_index) end @@ -408,28 +468,28 @@ end --- ShiftRows transformation --- @param state AESGCMState 4x4 state matrix local function shift_rows(state) - -- Row 0: no shift - -- Row 1: shift left by 1 - local temp = state[1][0] - state[1][0] = state[1][1] - state[1][1] = state[1][2] - state[1][2] = state[1][3] - state[1][3] = temp - - -- Row 2: shift left by 2 - temp = state[2][0] - state[2][0] = state[2][2] - state[2][2] = temp - temp = state[2][1] - state[2][1] = state[2][3] - state[2][3] = temp - - -- Row 3: shift left by 3 (or right by 1) - temp = state[3][3] - state[3][3] = state[3][2] - state[3][2] = state[3][1] - state[3][1] = state[3][0] - state[3][0] = temp + -- Row 1: no shift + -- Row 2: shift left by 1 + local temp = state[2][1] + state[2][1] = state[2][2] + state[2][2] = state[2][3] + state[2][3] = state[2][4] + state[2][4] = temp + + -- Row 3: shift left by 2 + temp = state[3][1] + state[3][1] = state[3][3] + state[3][3] = temp + temp = state[3][2] + state[3][2] = state[3][4] + state[3][4] = temp + + -- Row 4: shift left by 3 (or right by 1) + temp = state[4][4] + state[4][4] = state[4][3] + state[4][3] = state[4][2] + state[4][2] = state[4][1] + state[4][1] = temp end --- AddRoundKey transformation @@ -437,10 +497,10 @@ end --- @param round_key table Round key words --- @param round integer Round number local function add_round_key(state, round_key, round) - for c = 0, 3 do + for c = 1, 4 do local key_word = round_key[round * 4 + c] - for r = 0, 3 do - state[r][c] = bit32.bxor(state[r][c], key_word[r + 1]) + for r = 1, 4 do + state[r][c] = bit32_bxor(state[r][c], key_word[r]) end end end @@ -451,14 +511,11 @@ end --- @param nr integer Number of rounds --- @return string ciphertext 16-byte encrypted block local function aes_encrypt_block(input, expanded_key, nr) - -- Initialize state from input - --- @type AESGCMState - local state = {} - for i = 0, 3 do - --- @type AESGCMWord - state[i] = {} - for j = 0, 3 do - state[i][j] = string.byte(input, j * 4 + i + 1) + -- Reuse pre-allocated state array + local state = aes_state + for i = 1, 4 do + for j = 1, 4 do + state[i][j] = string_byte(input, (j - 1) * 4 + i) end end @@ -481,14 +538,14 @@ local function aes_encrypt_block(input, expanded_key, nr) -- Convert state to output (optimized with table) local output_bytes = {} local idx = 1 - for j = 0, 3 do - for i = 0, 3 do - output_bytes[idx] = string.char(state[i][j]) + for j = 1, 4 do + for i = 1, 4 do + output_bytes[idx] = string_char(state[i][j]) idx = idx + 1 end end - return table.concat(output_bytes) + return table_concat(output_bytes) end -- ============================================================================ @@ -500,40 +557,40 @@ end --- @param y string 16-byte block --- @return string result Product in GF(2^128) local function gcm_multiply(x, y) - -- Convert to bit arrays for easier manipulation - --- @type AESGCMBlock - local z = { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 } - --- @type AESGCMBlock - local v = {} + -- Reuse pre-allocated arrays + local z = gcm_z + local v = gcm_v + -- Reset z and initialize v for i = 1, 16 do - v[i] = string.byte(y, i) + z[i] = 0 + v[i] = string_byte(y, i) end -- Process each bit of x from MSB to LSB for i = 1, 16 do - local byte = string.byte(x, i) + local byte = string_byte(x, i) for bit = 7, 0, -1 do - if bit32.band(byte, bit32.lshift(1, bit)) ~= 0 then + if bit32_band(byte, bit32_lshift(1, bit)) ~= 0 then -- z = z XOR v for j = 1, 16 do - z[j] = bit32.bxor(z[j], v[j]) + z[j] = bit32_bxor(z[j], v[j]) end end -- Check if LSB of v is 1 (bit 0 of last byte) - local lsb = bit32.band(v[16], 1) + local lsb = bit32_band(v[16], 1) -- v = v >> 1 (right shift entire 128-bit value by 1) local carry = 0 for j = 1, 16 do - local new_carry = bit32.band(v[j], 1) - v[j] = bit32.bor(bit32.rshift(v[j], 1), bit32.lshift(carry, 7)) + local new_carry = bit32_band(v[j], 1) + v[j] = bit32_bor(bit32_rshift(v[j], 1), bit32_lshift(carry, 7)) carry = new_carry end -- If LSB was 1, XOR with R = 0xE1000000000000000000000000000000 if lsb ~= 0 then - v[1] = bit32.bxor(v[1], 0xE1) + v[1] = bit32_bxor(v[1], 0xE1) end end end @@ -541,7 +598,7 @@ local function gcm_multiply(x, y) -- Convert result back to string local result = "" for i = 1, 16 do - result = result .. string.char(z[i]) + result = result .. string_char(z[i]) end return result end @@ -551,16 +608,16 @@ end --- @param data string Data to hash (multiple of 16 bytes) --- @return string result 16-byte hash local function ghash(h, data) - local y = string.rep("\0", 16) + local y = string_rep("\0", 16) -- Process each 16-byte block for i = 1, #data, 16 do - local block = string.sub(data, i, i + 15) + local block = string_sub(data, i, i + 15) -- y = (y XOR block) * h local y_xor = "" for j = 1, 16 do - y_xor = y_xor .. string.char(bit32.bxor(string.byte(y, j), string.byte(block, j))) + y_xor = y_xor .. string_char(bit32_bxor(string_byte(y, j), string_byte(block, j))) end y = gcm_multiply(y_xor, h) @@ -573,19 +630,19 @@ end --- @param counter string 16-byte counter block --- @return string result Incremented counter local function inc_counter(counter) - local result = string.sub(counter, 1, 12) -- Keep first 12 bytes + local result = string_sub(counter, 1, 12) -- Keep first 12 bytes -- Increment last 4 bytes (big-endian) local val = 0 for i = 13, 16 do - val = val * 256 + string.byte(counter, i) + val = val * 256 + string_byte(counter, i) end val = (val + 1) % 0x100000000 -- Convert back to bytes (big-endian) for i = 3, 0, -1 do - result = result .. string.char(bit32.band(bit32.rshift(val, i * 8), 0xFF)) + result = result .. string_char(bit32_band(bit32_rshift(val, i * 8), 0xFF)) end return result @@ -602,7 +659,7 @@ local function generate_keystream(key, iv, length) local total_length = 0 -- Initial counter value: IV || 0x00000002 - local counter = iv .. string.rep("\0", 3) .. string.char(0x02) + local counter = iv .. string_rep("\0", 3) .. string_char(0x02) while total_length < length do local block = aes_encrypt_block(counter, expanded_key, nr) @@ -611,8 +668,8 @@ local function generate_keystream(key, iv, length) counter = inc_counter(counter) end - local keystream = table.concat(keystream_blocks) - return string.sub(keystream, 1, length) + local keystream = table_concat(keystream_blocks) + return string_sub(keystream, 1, length) end -- ============================================================================ @@ -629,12 +686,12 @@ local function format_gcm_data(aad, ciphertext) -- Add AAD and padding result = result .. aad local aad_pad = (16 - (#aad % 16)) % 16 - result = result .. string.rep("\0", aad_pad) + result = result .. string_rep("\0", aad_pad) -- Add ciphertext and padding result = result .. ciphertext local ct_pad = (16 - (#ciphertext % 16)) % 16 - result = result .. string.rep("\0", ct_pad) + result = result .. string_rep("\0", ct_pad) -- Add lengths (in bits) as 64-bit big-endian integers -- For messages under 2^61 bytes, high 32 bits are always 0 @@ -642,11 +699,11 @@ local function format_gcm_data(aad, ciphertext) local ct_bits_low = #ciphertext * 8 -- AAD length (64 bits big-endian) - result = result .. string.rep("\0", 4) -- High 32 bits + result = result .. string_rep("\0", 4) -- High 32 bits result = result .. bytes.u32_to_be_bytes(aad_bits_low) -- Low 32 bits -- Ciphertext length (64 bits big-endian) - result = result .. string.rep("\0", 4) -- High 32 bits + result = result .. string_rep("\0", 4) -- High 32 bits result = result .. bytes.u32_to_be_bytes(ct_bits_low) -- Low 32 bits return result @@ -696,16 +753,16 @@ function aes_gcm.encrypt(key, nonce, plaintext, aad) local expanded_key, nr = key_expansion(key) -- Generate hash key H = E(K, 0^128) - local h = aes_encrypt_block(string.rep("\0", 16), expanded_key, nr) + local h = aes_encrypt_block(string_rep("\0", 16), expanded_key, nr) -- Initial counter: nonce || 0x00000001 - local j0 = nonce .. string.rep("\0", 3) .. string.char(0x01) + local j0 = nonce .. string_rep("\0", 3) .. string_char(0x01) -- Encrypt plaintext using CTR mode local keystream = generate_keystream(key, nonce, #plaintext) local ciphertext = "" for i = 1, #plaintext do - ciphertext = ciphertext .. string.char(bit32.bxor(string.byte(plaintext, i), string.byte(keystream, i))) + ciphertext = ciphertext .. string_char(bit32_bxor(string_byte(plaintext, i), string_byte(keystream, i))) end -- Calculate authentication tag @@ -716,7 +773,7 @@ function aes_gcm.encrypt(key, nonce, plaintext, aad) local encrypted_j0 = aes_encrypt_block(j0, expanded_key, nr) local tag = "" for i = 1, 16 do - tag = tag .. string.char(bit32.bxor(string.byte(s, i), string.byte(encrypted_j0, i))) + tag = tag .. string_char(bit32_bxor(string_byte(s, i), string_byte(encrypted_j0, i))) end return ciphertext .. tag @@ -741,8 +798,8 @@ function aes_gcm.decrypt(key, nonce, ciphertext_and_tag, aad) -- Split ciphertext and tag local ciphertext_len = #ciphertext_and_tag - 16 - local ciphertext = string.sub(ciphertext_and_tag, 1, ciphertext_len) - local received_tag = string.sub(ciphertext_and_tag, ciphertext_len + 1) + local ciphertext = string_sub(ciphertext_and_tag, 1, ciphertext_len) + local received_tag = string_sub(ciphertext_and_tag, ciphertext_len + 1) local openssl = openssl_wrapper.get(openssl_wrapper.Feature.AAD) if openssl then @@ -771,10 +828,10 @@ function aes_gcm.decrypt(key, nonce, ciphertext_and_tag, aad) local expanded_key, nr = key_expansion(key) -- Generate hash key H = E(K, 0^128) - local h = aes_encrypt_block(string.rep("\0", 16), expanded_key, nr) + local h = aes_encrypt_block(string_rep("\0", 16), expanded_key, nr) -- Initial counter: nonce || 0x00000001 - local j0 = nonce .. string.rep("\0", 3) .. string.char(0x01) + local j0 = nonce .. string_rep("\0", 3) .. string_char(0x01) -- Calculate expected authentication tag local gcm_data = format_gcm_data(aad, ciphertext) @@ -784,7 +841,7 @@ function aes_gcm.decrypt(key, nonce, ciphertext_and_tag, aad) local encrypted_j0 = aes_encrypt_block(j0, expanded_key, nr) local expected_tag = "" for i = 1, 16 do - expected_tag = expected_tag .. string.char(bit32.bxor(string.byte(s, i), string.byte(encrypted_j0, i))) + expected_tag = expected_tag .. string_char(bit32_bxor(string_byte(s, i), string_byte(encrypted_j0, i))) end -- Verify tag (constant-time comparison) @@ -796,7 +853,7 @@ function aes_gcm.decrypt(key, nonce, ciphertext_and_tag, aad) local keystream = generate_keystream(key, nonce, #ciphertext) local plaintext = "" for i = 1, #ciphertext do - plaintext = plaintext .. string.char(bit32.bxor(string.byte(ciphertext, i), string.byte(keystream, i))) + plaintext = plaintext .. string_char(bit32_bxor(string_byte(ciphertext, i), string_byte(keystream, i))) end return plaintext @@ -806,8 +863,8 @@ end local test_vectors = { { name = "NIST Test Case 1 (AES-128-GCM)", - key = string.rep("\0", 16), - nonce = string.rep("\0", 12), + key = string_rep("\0", 16), + nonce = string_rep("\0", 12), plaintext = "", aad = "", ciphertext = "", @@ -815,9 +872,9 @@ local test_vectors = { }, { name = "NIST Test Case 2 (AES-128-GCM)", - key = string.rep("\0", 16), - nonce = string.rep("\0", 12), - plaintext = string.rep("\0", 16), + key = string_rep("\0", 16), + nonce = string_rep("\0", 12), + plaintext = string_rep("\0", 16), aad = "", ciphertext = bytes.from_hex("0388dace60b6a392f328c2b971b2fe78"), tag = bytes.from_hex("ab6e47d42cec13bdf53a67b21257bddf"), @@ -858,8 +915,8 @@ function aes_gcm.selftest() if test.ciphertext then -- Test with known ciphertext and tag local result = aes_gcm.encrypt(test.key, test.nonce, test.plaintext, test.aad) - local result_ct = string.sub(result, 1, #test.ciphertext) - local result_tag = string.sub(result, #test.ciphertext + 1) + local result_ct = string_sub(result, 1, #test.ciphertext) + local result_tag = string_sub(result, #test.ciphertext + 1) if result_ct == test.ciphertext and result_tag == test.tag then print(" ✅ PASS: Encryption") @@ -878,7 +935,7 @@ function aes_gcm.selftest() print(" ❌ FAIL: Encryption") print(" Expected CT: " .. bytes.to_hex(test.ciphertext)) print(" Got CT: " .. bytes.to_hex(result_ct)) - print(" Expected Tag: " .. bytes.to_hex(test.tag)) + print(" Expected Tag: " .. (test.tag and bytes.to_hex(test.tag) or "none")) print(" Got Tag: " .. bytes.to_hex(result_tag)) end else @@ -908,8 +965,8 @@ function aes_gcm.selftest() -- Test 1: Basic encryption/decryption with AES-128 total = total + 1 - local key128 = string.rep(string.char(0x42), 16) - local nonce = string.rep("\0", 11) .. string.char(0x01) + local key128 = string_rep(string_char(0x42), 16) + local nonce = string_rep("\0", 11) .. string_char(0x01) local aad = "user@example.com|2024-01-01" local plaintext = "This is a secret message that needs both encryption and authentication." @@ -925,7 +982,7 @@ function aes_gcm.selftest() -- Test 2: Basic encryption/decryption with AES-256 total = total + 1 - local key256 = string.rep(string.char(0x43), 32) + local key256 = string_rep(string_char(0x43), 32) local ct256 = aes_gcm.encrypt(key256, nonce, plaintext, aad) local pt256 = aes_gcm.decrypt(key256, nonce, ct256, aad) @@ -938,7 +995,7 @@ function aes_gcm.selftest() -- Test 3: Authentication tag tampering detection total = total + 1 - local tampered = ciphertext_and_tag:sub(1, -2) .. string.char(255) + local tampered = ciphertext_and_tag:sub(1, -2) .. string_char(255) local tampered_result = aes_gcm.decrypt(key128, nonce, tampered, aad) if tampered_result == nil then @@ -962,7 +1019,7 @@ function aes_gcm.selftest() -- Test 5: Nonce uniqueness total = total + 1 - local nonce2 = string.rep("\0", 11) .. string.char(0x02) + local nonce2 = string_rep("\0", 11) .. string_char(0x02) local ciphertext2 = aes_gcm.encrypt(key128, nonce2, plaintext, aad) if ciphertext_and_tag ~= ciphertext2 then @@ -998,7 +1055,7 @@ function aes_gcm.selftest() -- Test 8: Ciphertext tampering detection total = total + 1 - local tampered_ct = string.char(255) .. ciphertext_and_tag:sub(2) + local tampered_ct = string_char(255) .. ciphertext_and_tag:sub(2) local tampered_ct_result = aes_gcm.decrypt(key128, nonce, tampered_ct, aad) if tampered_ct_result == nil then @@ -1010,7 +1067,7 @@ function aes_gcm.selftest() -- Test 9: Wrong key detection total = total + 1 - local wrong_key = string.rep(string.char(0x99), 16) + local wrong_key = string_rep(string_char(0x99), 16) local wrong_key_result = aes_gcm.decrypt(wrong_key, nonce, ciphertext_and_tag, aad) if wrong_key_result == nil then @@ -1022,7 +1079,7 @@ function aes_gcm.selftest() -- Test 10: Large plaintext (multiple blocks) total = total + 1 - local large_plaintext = string.rep("A", 1000) + local large_plaintext = string_rep("A", 1000) local large_ct = aes_gcm.encrypt(key128, nonce, large_plaintext, aad) local large_pt = aes_gcm.decrypt(key128, nonce, large_ct, aad) @@ -1035,7 +1092,7 @@ function aes_gcm.selftest() -- Test 11: Different key sizes produce different outputs total = total + 1 - local key192 = string.rep(string.char(0x44), 24) + local key192 = string_rep(string_char(0x44), 24) local ct128 = aes_gcm.encrypt(key128, nonce, plaintext, aad) local ct192 = aes_gcm.encrypt(key192, nonce, plaintext, aad) local ct256_2 = aes_gcm.encrypt(key256, nonce, plaintext, aad) @@ -1067,9 +1124,9 @@ function aes_gcm.benchmark() local key256 = bytes.from_hex("feffe9928665731c6d6a8f9467308308feffe9928665731c6d6a8f9467308308") local nonce = bytes.from_hex("cafebabefacedbaddecaf888") local aad = "feedfacedeadbeeffeedfacedeadbeefabaddad2" - local plaintext_64 = string.rep("a", 64) - local plaintext_1k = string.rep("a", 1024) - local plaintext_8k = string.rep("a", 8192) + local plaintext_64 = string_rep("a", 64) + local plaintext_1k = string_rep("a", 1024) + local plaintext_8k = string_rep("a", 8192) print("AES-128-GCM Encryption:") benchmark_op("aes128_encrypt_64_bytes", function() diff --git a/src/noiseprotocol/crypto/blake2.lua b/src/noiseprotocol/crypto/blake2.lua index 8c59264..a4fd561 100644 --- a/src/noiseprotocol/crypto/blake2.lua +++ b/src/noiseprotocol/crypto/blake2.lua @@ -1,8 +1,9 @@ --- @module "noiseprotocol.crypto.blake2" --- Pure Lua BLAKE2s and BLAKE2b Implementation for portability. +--- @class noiseprotocol.crypto.blake2 local blake2 = {} -local bitn = require("vendor.bitn") +local bitn = require("bitn") local bit32 = bitn.bit32 local bit64 = bitn.bit64 @@ -11,6 +12,19 @@ local utils = require("noiseprotocol.utils") local bytes = utils.bytes local benchmark_op = utils.benchmark.benchmark_op +-- Local references for performance (avoid module table lookups in hot loops) +local bit32_add = bit32.add +local bit32_bxor = bit32.bxor +local bit32_ror = bit32.ror +local bit64_add = bit64.add +local bit64_xor = bit64.xor +local bit64_ror = bit64.ror +local bit64_new = bit64.new +local string_byte = string.byte +local string_char = string.char +local string_rep = string.rep +local table_concat = table.concat + -- BLAKE2s initialization vectors (first 32 bits of fractional parts of square roots of first 8 primes) --- @type HashState local BLAKE2S_IV = { @@ -58,6 +72,34 @@ local BLAKE2S_SIGMA = { -- BLAKE2b permutation table (same as BLAKE2s) local BLAKE2B_SIGMA = BLAKE2S_SIGMA +--- Initialize a 16-element BLAKE2s working vector with zeros +--- @return Blake2sVector16 array Initialized array +local function create_blake2s_vector() + local arr = {} + for i = 1, 16 do + arr[i] = 0 + end + --- @cast arr Blake2sVector16 + return arr +end + +--- Initialize a 16-element BLAKE2b working vector with zeros +--- @return Blake2bVector16 array Initialized array +local function create_blake2b_vector() + local arr = {} + for i = 1, 16 do + arr[i] = bit64_new(0, 0) + end + --- @cast arr Blake2bVector16 + return arr +end + +-- Pre-allocated arrays for blake2s_compress() to avoid repeated allocation +local blake2s_v = create_blake2s_vector() + +-- Pre-allocated arrays for blake2b_compress() to avoid repeated allocation +local blake2b_v = create_blake2b_vector() + --- BLAKE2s G function --- @param v Blake2sVector16 Working vector --- @param a integer Index a @@ -67,14 +109,14 @@ local BLAKE2B_SIGMA = BLAKE2S_SIGMA --- @param x integer Message word x --- @param y integer Message word y local function blake2s_g(v, a, b, c, d, x, y) - v[a] = bit32.add(bit32.add(v[a], v[b]), x) - v[d] = bit32.ror(bit32.bxor(v[d], v[a]), 16) - v[c] = bit32.add(v[c], v[d]) - v[b] = bit32.ror(bit32.bxor(v[b], v[c]), 12) - v[a] = bit32.add(bit32.add(v[a], v[b]), y) - v[d] = bit32.ror(bit32.bxor(v[d], v[a]), 8) - v[c] = bit32.add(v[c], v[d]) - v[b] = bit32.ror(bit32.bxor(v[b], v[c]), 7) + v[a] = bit32_add(bit32_add(v[a], v[b]), x) + v[d] = bit32_ror(bit32_bxor(v[d], v[a]), 16) + v[c] = bit32_add(v[c], v[d]) + v[b] = bit32_ror(bit32_bxor(v[b], v[c]), 12) + v[a] = bit32_add(bit32_add(v[a], v[b]), y) + v[d] = bit32_ror(bit32_bxor(v[d], v[a]), 8) + v[c] = bit32_add(v[c], v[d]) + v[b] = bit32_ror(bit32_bxor(v[b], v[c]), 7) end --- BLAKE2b G function @@ -86,14 +128,14 @@ end --- @param x table Message word x --- @param y table Message word y local function blake2b_g(v, a, b, c, d, x, y) - v[a] = bit64.add(bit64.add(v[a], v[b]), x) - v[d] = bit64.ror(bit64.xor(v[d], v[a]), 32) - v[c] = bit64.add(v[c], v[d]) - v[b] = bit64.ror(bit64.xor(v[b], v[c]), 24) - v[a] = bit64.add(bit64.add(v[a], v[b]), y) - v[d] = bit64.ror(bit64.xor(v[d], v[a]), 16) - v[c] = bit64.add(v[c], v[d]) - v[b] = bit64.ror(bit64.xor(v[b], v[c]), 63) + v[a] = bit64_add(bit64_add(v[a], v[b]), x) + v[d] = bit64_ror(bit64_xor(v[d], v[a]), 32) + v[c] = bit64_add(v[c], v[d]) + v[b] = bit64_ror(bit64_xor(v[b], v[c]), 24) + v[a] = bit64_add(bit64_add(v[a], v[b]), y) + v[d] = bit64_ror(bit64_xor(v[d], v[a]), 16) + v[c] = bit64_add(v[c], v[d]) + v[b] = bit64_ror(bit64_xor(v[b], v[c]), 63) end --- BLAKE2s compression function @@ -103,9 +145,8 @@ end --- @param th integer Counter (high 32 bits) --- @param f boolean Final block flag local function blake2s_compress(h, m, t, th, f) - -- Initialize working vector - --- @type Blake2sVector16 - local v = {} + -- Reuse pre-allocated working vector + local v = blake2s_v -- First half from hash state for i = 1, 8 do @@ -118,10 +159,10 @@ local function blake2s_compress(h, m, t, th, f) end -- Mix in counter and final flag - v[13] = bit32.bxor(v[13], t) -- Low 32 bits of counter - v[14] = bit32.bxor(v[14], th) -- High 32 bits of counter + v[13] = bit32_bxor(v[13], t) -- Low 32 bits of counter + v[14] = bit32_bxor(v[14], th) -- High 32 bits of counter if f then - v[15] = bit32.bxor(v[15], 0xFFFFFFFF) -- Invert all bits for final block + v[15] = bit32_bxor(v[15], 0xFFFFFFFF) -- Invert all bits for final block end -- 10 rounds @@ -144,7 +185,7 @@ local function blake2s_compress(h, m, t, th, f) -- Finalize for i = 1, 8 do - h[i] = bit32.bxor(bit32.bxor(h[i], v[i]), v[i + 8]) + h[i] = bit32_bxor(bit32_bxor(h[i], v[i]), v[i + 8]) end end @@ -154,25 +195,24 @@ end --- @param t table Counter (64-bit) --- @param f boolean Final block flag local function blake2b_compress(h, m, t, f) - -- Initialize working vector - --- @type Blake2bVector16 - local v = {} + -- Reuse pre-allocated working vector + local v = blake2b_v -- First half from hash state for i = 1, 8 do - v[i] = { h[i][1], h[i][2] } + v[i][1], v[i][2] = h[i][1], h[i][2] end -- Second half from IV for i = 1, 8 do - v[8 + i] = { BLAKE2B_IV[i][1], BLAKE2B_IV[i][2] } + v[8 + i][1], v[8 + i][2] = BLAKE2B_IV[i][1], BLAKE2B_IV[i][2] end -- Mix in counter and final flag - v[13] = bit64.xor(v[13], t) - v[14] = bit64.xor(v[14], { 0, 0 }) -- High 64 bits of counter (always 0 for messages < 2^64 bytes) + v[13] = bit64_xor(v[13], t) + v[14] = bit64_xor(v[14], bit64_new(0, 0)) -- High 64 bits of counter (always 0 for messages < 2^64 bytes) if f then - v[15] = bit64.xor(v[15], { 0xffffffff, 0xffffffff }) + v[15] = bit64_xor(v[15], bit64_new(0xffffffff, 0xffffffff)) end -- 12 rounds @@ -195,7 +235,7 @@ local function blake2b_compress(h, m, t, f) -- Finalize for i = 1, 8 do - h[i] = bit64.xor(bit64.xor(h[i], v[i]), v[i + 8]) + h[i] = bit64_xor(bit64_xor(h[i], v[i]), v[i + 8]) end end @@ -220,7 +260,7 @@ function blake2.blake2s(data) -- Parameter block: digest length = 32, key length = 0, fanout = 1, depth = 1 -- All other parameters are 0 (no salt, no personalization, etc.) local param = 32 + (0 * 256) + (1 * 65536) + (1 * 16777216) -- 0x01010020 - h[1] = bit32.bxor(h[1], param) + h[1] = bit32_bxor(h[1], param) local data_len = #data local offset = 1 @@ -253,7 +293,7 @@ function blake2.blake2s(data) -- Pad final block with zeros local final_data = data:sub(offset) - local final_block = final_data .. string.rep("\0", 64 - remaining) + local final_block = final_data .. string_rep("\0", 64 - remaining) --- @type Blake2sVector16 local m = {} @@ -279,7 +319,7 @@ function blake2.blake2s(data) result_bytes[i] = bytes.u32_to_le_bytes(h[i]) end - return table.concat(result_bytes) + return table_concat(result_bytes) end --- Compute BLAKE2b hash of input data @@ -307,15 +347,15 @@ function blake2.blake2b(data) -- In little-endian 64-bit: 0x0000000001010040 -- Split into two 32-bit words (little-endian): low=0x01010040, high=0x00000000 -- But our u64 format is {high, low}, so we need {0x00000000, 0x01010040} - h[1] = bit64.xor(h[1], { 0x00000000, 0x01010040 }) + h[1] = bit64_xor(h[1], bit64_new(0x00000000, 0x01010040)) local data_len = #data local offset = 1 - local counter = { 0, 0 } + local counter = bit64_new(0, 0) -- Process full 128-byte blocks while offset + 127 <= data_len do - counter = bit64.add(counter, { 0, 128 }) + counter = bit64_add(counter, bit64_new(0, 128)) -- Check if this is the last block local is_last_block = (offset + 128 > data_len) @@ -334,10 +374,10 @@ function blake2.blake2b(data) -- Process final block (if there's remaining data) local remaining = data_len - offset + 1 if remaining > 0 then - counter = bit64.add(counter, { 0, remaining }) + counter = bit64_add(counter, bit64_new(0, remaining)) -- Pad final block with zeros - local final_block = data:sub(offset) .. string.rep("\0", 128 - remaining) + local final_block = data:sub(offset) .. string_rep("\0", 128 - remaining) --- @type Blake2bVector16 local m = {} @@ -351,9 +391,9 @@ function blake2.blake2b(data) --- @type Blake2bVector16 local m = {} for i = 1, 16 do - m[i] = { 0, 0 } + m[i] = bit64_new(0, 0) end - blake2b_compress(h, m, { 0, 0 }, true) + blake2b_compress(h, m, bit64_new(0, 0), true) end -- Produce final hash value as binary string (optimized with table) @@ -362,7 +402,7 @@ function blake2.blake2b(data) result_bytes[i] = bytes.u64_to_le_bytes(h[i]) end - return table.concat(result_bytes) + return table_concat(result_bytes) end --- Compute BLAKE2s hash and return as hex string @@ -403,19 +443,19 @@ function blake2.hmac_blake2s(key, data) -- Keys shorter than blocksize are right-padded with zeros if #key < block_size then - key = key .. string.rep("\0", block_size - #key) + key = key .. string_rep("\0", block_size - #key) end -- Compute inner and outer padding (optimized with table) local ipad_bytes = {} local opad_bytes = {} for i = 1, block_size do - local byte = string.byte(key, i) - ipad_bytes[i] = string.char(bit32.bxor(byte, 0x36)) - opad_bytes[i] = string.char(bit32.bxor(byte, 0x5C)) + local byte = string_byte(key, i) + ipad_bytes[i] = string_char(bit32_bxor(byte, 0x36)) + opad_bytes[i] = string_char(bit32_bxor(byte, 0x5C)) end - local ipad = table.concat(ipad_bytes) - local opad = table.concat(opad_bytes) + local ipad = table_concat(ipad_bytes) + local opad = table_concat(opad_bytes) -- Compute HMAC = H(opad || H(ipad || data)) local inner_hash = blake2.blake2s(ipad .. data) @@ -446,19 +486,19 @@ function blake2.hmac_blake2b(key, data) -- Keys shorter than blocksize are right-padded with zeros if #key < block_size then - key = key .. string.rep("\0", block_size - #key) + key = key .. string_rep("\0", block_size - #key) end -- Compute inner and outer padding (optimized with table) local ipad_bytes = {} local opad_bytes = {} for i = 1, block_size do - local byte = string.byte(key, i) - ipad_bytes[i] = string.char(bit32.bxor(byte, 0x36)) - opad_bytes[i] = string.char(bit32.bxor(byte, 0x5C)) + local byte = string_byte(key, i) + ipad_bytes[i] = string_char(bit32_bxor(byte, 0x36)) + opad_bytes[i] = string_char(bit32_bxor(byte, 0x5C)) end - local ipad = table.concat(ipad_bytes) - local opad = table.concat(opad_bytes) + local ipad = table_concat(ipad_bytes) + local opad = table_concat(opad_bytes) -- Compute HMAC = H(opad || H(ipad || data)) local inner_hash = blake2.blake2b(ipad .. data) @@ -536,7 +576,7 @@ local hmac_blake2s_test_vectors = { }, { name = "RFC 4231 Test Case 1 pattern", - key = string.rep(string.char(0x0b), 20), + key = string_rep(string_char(0x0b), 20), message = "Hi There", expected = "65a8b7c5cc9136d424e82c37e2707e74e913c0655b99c75f40edf387453a3260", }, @@ -548,13 +588,13 @@ local hmac_blake2s_test_vectors = { }, { name = "Key = block size (64 bytes)", - key = string.rep("a", 64), + key = string_rep("a", 64), message = "Test message", expected = "12d0e782ae473d8007d33ae6e5244afcaf9239f6a7d5476c69060c01383d6b58", }, { name = "Key > block size (80 bytes)", - key = string.rep("a", 80), + key = string_rep("a", 80), message = "Test message", expected = "41da357bda1107f9fad1a504b5afbe75f5ead5ed7cf8f82e59e18c5e9e653882", }, @@ -575,7 +615,7 @@ local hmac_blake2b_test_vectors = { }, { name = "RFC 4231 Test Case 1 pattern", - key = string.rep(string.char(0x0b), 20), + key = string_rep(string_char(0x0b), 20), message = "Hi There", expected = "358a6a184924894fc34bee5680eedf57d84a37bb38832f288e3b27dc63a98cc8c91e76da476b508bc6b2d408a248857452906e4a20b48c6b4b55d2df0fe1dd24", }, @@ -587,13 +627,13 @@ local hmac_blake2b_test_vectors = { }, { name = "Key = block size (128 bytes)", - key = string.rep("a", 128), + key = string_rep("a", 128), message = "Test message", expected = "021a22a3ecf0f1f7a15aca6a5d9704fc99b6a84a627fa53f7ac932a961ffb69b1e68c46981d5b44fd00a7cae75e4ee63d393eec844a8de2dd00e45b5a0d4e275", }, { name = "Key > block size (80 bytes)", - key = string.rep("a", 80), + key = string_rep("a", 80), message = "Test message", expected = "1c8fb6f426d7800000e8d03c141905b33d10a4da16f9c018140955c5cedfa7a017204aaea1f141c1c0d3d942dee04a795a6e589898c1328b717ad6053a7b4790", }, @@ -745,9 +785,9 @@ end --- including BLAKE2s and BLAKE2b hash computation for various message sizes. function blake2.benchmark() -- Test data - local message_64 = string.rep("a", 64) - local message_1k = string.rep("a", 1024) - local message_8k = string.rep("a", 8192) + local message_64 = string_rep("a", 64) + local message_1k = string_rep("a", 1024) + local message_8k = string_rep("a", 8192) local hmac_key = "benchmark_key" print("BLAKE2s Hash Operations:") diff --git a/src/noiseprotocol/crypto/chacha20.lua b/src/noiseprotocol/crypto/chacha20.lua index d48e39f..d341787 100644 --- a/src/noiseprotocol/crypto/chacha20.lua +++ b/src/noiseprotocol/crypto/chacha20.lua @@ -1,14 +1,26 @@ --- @module "noiseprotocol.crypto.chacha20" --- ChaCha20 Stream Cipher Implementation for portability. +--- @class noiseprotocol.crypto.chacha20 local chacha20 = {} -local bit32 = require("vendor.bitn").bit32 +local bit32 = require("bitn").bit32 local openssl_wrapper = require("noiseprotocol.openssl_wrapper") local utils = require("noiseprotocol.utils") local bytes = utils.bytes local benchmark_op = utils.benchmark.benchmark_op +-- Local references for performance (avoid module table lookups in hot loops) +local bit32_add = bit32.add +local bit32_bxor = bit32.bxor +local bit32_rol = bit32.rol +local floor = math.floor +local min = math.min +local string_byte = string.byte +local string_char = string.char +local string_rep = string.rep +local table_concat = table.concat + -- Type definitions for better type checking --- 16-element array of 32-bit words @@ -41,16 +53,20 @@ local function create_word_array() return arr end +-- Pre-allocated arrays for chacha20_block() to avoid repeated allocation +local block_state = create_word_array() +local block_working = create_word_array() + --- Convert 32-bit word to 4 bytes (little-endian) --- @param word integer 32-bit word --- @return integer, integer, integer, integer bytes Four bytes in little-endian order local function word_to_bytes(word) local byte1 = word % 256 - word = math.floor(word / 256) + word = floor(word * 0.00390625) -- / 256 local byte2 = word % 256 - word = math.floor(word / 256) + word = floor(word * 0.00390625) local byte3 = word % 256 - word = math.floor(word / 256) + word = floor(word * 0.00390625) local byte4 = word % 256 return byte1, byte2, byte3, byte4 @@ -73,31 +89,34 @@ end --- @param c integer Index of third word --- @param d integer Index of fourth word local function quarter_round(state, a, b, c, d) - state[a] = bit32.add(state[a], state[b]) - state[d] = bit32.rol(bit32.bxor(state[d], state[a]), 16) + state[a] = bit32_add(state[a], state[b]) + state[d] = bit32_rol(bit32_bxor(state[d], state[a]), 16) - state[c] = bit32.add(state[c], state[d]) - state[b] = bit32.rol(bit32.bxor(state[b], state[c]), 12) + state[c] = bit32_add(state[c], state[d]) + state[b] = bit32_rol(bit32_bxor(state[b], state[c]), 12) - state[a] = bit32.add(state[a], state[b]) - state[d] = bit32.rol(bit32.bxor(state[d], state[a]), 8) + state[a] = bit32_add(state[a], state[b]) + state[d] = bit32_rol(bit32_bxor(state[d], state[a]), 8) - state[c] = bit32.add(state[c], state[d]) - state[b] = bit32.rol(bit32.bxor(state[b], state[c]), 7) + state[c] = bit32_add(state[c], state[d]) + state[b] = bit32_rol(bit32_bxor(state[b], state[c]), 7) end ---- Initialize ChaCha20 state with key, nonce, and counter +--- Generate one 64-byte block of ChaCha20 keystream --- @param key string 32-byte key --- @param nonce string 12-byte nonce --- @param counter integer 32-bit counter value ---- @return Word32Array state Initialized 16-word state -local function chacha20_init(key, nonce, counter) +--- @return string keystream 64-byte keystream block +local function chacha20_block(key, nonce, counter) + -- Reuse pre-allocated arrays + local state = block_state + local working_state = block_working + + -- Initialize state inline (avoiding function call overhead) assert(#key == 32, "Key must be exactly 32 bytes") assert(#nonce == 12, "Nonce must be exactly 12 bytes") assert(counter >= 0 and counter < 0x100000000, "Counter must be a valid 32-bit integer") - local state = create_word_array() - -- ChaCha20 constants "expand 32-byte k" state[1] = 0x61707865 -- "expa" state[2] = 0x3320646e -- "nd 3" @@ -108,10 +127,10 @@ local function chacha20_init(key, nonce, counter) for i = 1, 8 do local base = (i - 1) * 4 state[4 + i] = bytes_to_word( - string.byte(key, base + 1), - string.byte(key, base + 2), - string.byte(key, base + 3), - string.byte(key, base + 4) + string_byte(key, base + 1), + string_byte(key, base + 2), + string_byte(key, base + 3), + string_byte(key, base + 4) ) end @@ -122,26 +141,14 @@ local function chacha20_init(key, nonce, counter) for i = 1, 3 do local base = (i - 1) * 4 state[13 + i] = bytes_to_word( - string.byte(nonce, base + 1), - string.byte(nonce, base + 2), - string.byte(nonce, base + 3), - string.byte(nonce, base + 4) + string_byte(nonce, base + 1), + string_byte(nonce, base + 2), + string_byte(nonce, base + 3), + string_byte(nonce, base + 4) ) end - return state -end - ---- Generate one 64-byte block of ChaCha20 keystream ---- @param key string 32-byte key ---- @param nonce string 12-byte nonce ---- @param counter integer 32-bit counter value ---- @return string keystream 64-byte keystream block -local function chacha20_block(key, nonce, counter) - local state = chacha20_init(key, nonce, counter) - -- Create working copy of state - local working_state = create_word_array() for i = 1, 16 do working_state[i] = state[i] end @@ -163,17 +170,17 @@ local function chacha20_block(key, nonce, counter) -- Add original state to working state for i = 1, 16 do - working_state[i] = bit32.add(working_state[i], state[i]) + working_state[i] = bit32_add(working_state[i], state[i]) end - -- Convert state to byte string (little-endian) - optimized with table + -- Convert state to byte string (little-endian) - optimized with local references local result_bytes = {} for i = 1, 16 do local b1, b2, b3, b4 = word_to_bytes(working_state[i]) - result_bytes[i] = string.char(b1, b2, b3, b4) + result_bytes[i] = string_char(b1, b2, b3, b4) end - return table.concat(result_bytes) + return table_concat(result_bytes) end --- ChaCha20 encryption/decryption (same operation) @@ -194,12 +201,12 @@ function chacha20.crypt(key, nonce, plaintext, counter) -- Generate keystream block local keystream = chacha20_block(key, nonce, counter) - -- XOR with plaintext (optimized with table) - local block_size = math.min(64, data_len - offset + 1) + -- XOR with plaintext (optimized with local references) + local block_size = min(64, data_len - offset + 1) for i = 1, block_size do - local plaintext_byte = string.byte(plaintext, offset + i - 1) - local keystream_byte = string.byte(keystream, i) - result_bytes[result_idx] = string.char(bit32.bxor(plaintext_byte, keystream_byte)) + local plaintext_byte = string_byte(plaintext, offset + i - 1) + local keystream_byte = string_byte(keystream, i) + result_bytes[result_idx] = string_char(bit32_bxor(plaintext_byte, keystream_byte)) result_idx = result_idx + 1 end @@ -207,7 +214,7 @@ function chacha20.crypt(key, nonce, plaintext, counter) counter = counter + 1 end - return table.concat(result_bytes) + return table_concat(result_bytes) end --- Convenience function for encryption (same as crypt) @@ -221,7 +228,7 @@ function chacha20.encrypt(key, nonce, plaintext, counter) local openssl = openssl_wrapper.get() if openssl and #plaintext > 0 then -- Prepend 32-bit counter to 96-bit nonce for complete 128-bit nonce - nonce = utils.bytes.u32_to_le_bytes(counter or 1) .. nonce + nonce = bytes.u32_to_le_bytes(counter or 1) .. nonce return openssl.cipher.encrypt("chacha20", plaintext, key, nonce) end return chacha20.crypt(key, nonce, plaintext, counter) @@ -238,7 +245,7 @@ function chacha20.decrypt(key, nonce, ciphertext, counter) local openssl = openssl_wrapper.get() if openssl and #ciphertext > 0 then -- Prepend 32-bit counter to 96-bit nonce for complete 128-bit nonce - nonce = utils.bytes.u32_to_le_bytes(counter or 1) .. nonce + nonce = bytes.u32_to_le_bytes(counter or 1) .. nonce return openssl.cipher.decrypt("chacha20", ciphertext, key, nonce) end return chacha20.crypt(key, nonce, ciphertext, counter) @@ -278,10 +285,10 @@ local test_vectors = { }, { name = "Zero key test", - key = string.rep("\0", 32), - nonce = string.rep("\0", 12), + key = string_rep("\0", 32), + nonce = string_rep("\0", 12), counter = 0, - plaintext = string.rep("\0", 64), + plaintext = string_rep("\0", 64), expected_ciphertext = bytes.from_hex( "76b8e0ada0f13d90405d6ae55386bd28bdd219b8a08ded1aa836efcc8b770dc7da41597c5157488d7724e03fb8d84a376a43b8f41518a11cc387b669b2ee6586" ), @@ -322,11 +329,11 @@ function chacha20.selftest() -- Show first few bytes for debugging local expected_hex = "" local result_hex = "" - local show_bytes = math.min(16, #test.expected_keystream) + local show_bytes = min(16, #test.expected_keystream) for j = 1, show_bytes do - expected_hex = expected_hex .. string.format("%02x", string.byte(test.expected_keystream, j)) - result_hex = result_hex .. string.format("%02x", string.byte(keystream, j)) + expected_hex = expected_hex .. string.format("%02x", string_byte(assert(test.expected_keystream), j)) + result_hex = result_hex .. string.format("%02x", string_byte(keystream, j)) end print(" Expected (first " .. show_bytes .. " bytes): " .. expected_hex) @@ -351,11 +358,11 @@ function chacha20.selftest() -- Show first few bytes for debugging local expected_hex = "" local result_hex = "" - local show_bytes = math.min(16, #test.expected_ciphertext) + local show_bytes = min(16, #test.expected_ciphertext) for j = 1, show_bytes do - expected_hex = expected_hex .. string.format("%02x", string.byte(test.expected_ciphertext, j)) - result_hex = result_hex .. string.format("%02x", string.byte(result, j)) + expected_hex = expected_hex .. string.format("%02x", string_byte(assert(test.expected_ciphertext), j)) + result_hex = result_hex .. string.format("%02x", string_byte(result, j)) end print(" Expected (first " .. show_bytes .. " bytes): " .. expected_hex) @@ -379,8 +386,8 @@ function chacha20.selftest() -- Test 1: Basic encryption/decryption total = total + 1 - local key = string.rep(string.char(0x42), 32) - local nonce = string.rep("\0", 12) + local key = string_rep(string_char(0x42), 32) + local nonce = string_rep("\0", 12) local counter = 1 local plaintext = "Hello, ChaCha20! This is a test message for encryption." @@ -407,7 +414,7 @@ function chacha20.selftest() -- Test 3: Different nonces produce different output total = total + 1 - local nonce2 = string.char(0x01) .. string.rep("\0", 11) + local nonce2 = string_char(0x01) .. string_rep("\0", 11) local ciphertext3 = chacha20.encrypt(key, nonce2, plaintext, counter) if ciphertext ~= ciphertext3 then @@ -442,7 +449,7 @@ function chacha20.selftest() -- Test 6: Large plaintext (multi-block) total = total + 1 - local large_plaintext = string.rep("A", 256) -- 4 blocks + local large_plaintext = string_rep("A", 256) -- 4 blocks local large_ct = chacha20.encrypt(key, nonce, large_plaintext, counter) local large_pt = chacha20.decrypt(key, nonce, large_ct, counter) @@ -455,7 +462,7 @@ function chacha20.selftest() -- Test 7: Partial block total = total + 1 - local partial_plaintext = string.rep("B", 100) -- Not a multiple of 64 + local partial_plaintext = string_rep("B", 100) -- Not a multiple of 64 local partial_ct = chacha20.encrypt(key, nonce, partial_plaintext, counter) local partial_pt = chacha20.decrypt(key, nonce, partial_ct, counter) @@ -485,9 +492,9 @@ function chacha20.benchmark() -- Test data local key = bytes.from_hex("000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f") local nonce = bytes.from_hex("000000090000004a00000000") - local plaintext_64 = string.rep("a", 64) - local plaintext_1k = string.rep("a", 1024) - local plaintext_8k = string.rep("a", 8192) + local plaintext_64 = string_rep("a", 64) + local plaintext_1k = string_rep("a", 1024) + local plaintext_8k = string_rep("a", 8192) print("Encryption Operations:") benchmark_op("encrypt_64_bytes", function() diff --git a/src/noiseprotocol/crypto/chacha20_poly1305.lua b/src/noiseprotocol/crypto/chacha20_poly1305.lua index 686c516..20e92c5 100644 --- a/src/noiseprotocol/crypto/chacha20_poly1305.lua +++ b/src/noiseprotocol/crypto/chacha20_poly1305.lua @@ -1,5 +1,6 @@ --- @module "noiseprotocol.crypto.chacha20_poly1305" --- ChaCha20-Poly1305 Authenticated Encryption with Associated Data (AEAD) Implementation for portability. +--- @class noiseprotocol.crypto.chacha20_poly1305 local chacha20_poly1305 = {} local openssl_wrapper = require("noiseprotocol.openssl_wrapper") @@ -9,6 +10,12 @@ local benchmark_op = utils.benchmark.benchmark_op local chacha20 = require("noiseprotocol.crypto.chacha20") local poly1305 = require("noiseprotocol.crypto.poly1305") +-- Local references for performance (avoid module table lookups in hot loops) +local string_char = string.char +local string_rep = string.rep +local string_sub = string.sub +local table_concat = table.concat + --- Generate Poly1305 one-time key using ChaCha20 --- @param key string 32-byte ChaCha20 key --- @param nonce string 12-byte nonce @@ -16,7 +23,7 @@ local poly1305 = require("noiseprotocol.crypto.poly1305") local function poly1305_key_gen(key, nonce) -- Generate Poly1305 key by encrypting 32 zero bytes with ChaCha20 -- Counter starts at 0 for key generation - local zero_block = string.rep("\0", 32) + local zero_block = string_rep("\0", 32) return chacha20.crypt(key, nonce, zero_block, 0) end @@ -37,7 +44,7 @@ local function construct_aad_data(aad, ciphertext) bytes.u64_to_le_bytes(ciphertext_len), } - return table.concat(auth_parts) + return table_concat(auth_parts) end -- ============================================================================ @@ -119,8 +126,8 @@ function chacha20_poly1305.decrypt(key, nonce, ciphertext_and_tag, aad) -- Step 1: Split ciphertext and tag local ciphertext_len = #ciphertext_and_tag - 16 - local ciphertext = string.sub(ciphertext_and_tag, 1, ciphertext_len) - local received_tag = string.sub(ciphertext_and_tag, ciphertext_len + 1) + local ciphertext = string_sub(ciphertext_and_tag, 1, ciphertext_len) + local received_tag = string_sub(ciphertext_and_tag, ciphertext_len + 1) local openssl = openssl_wrapper.get(openssl_wrapper.Feature.AAD) if openssl then @@ -198,14 +205,14 @@ local test_vectors = { }, { name = "Empty AAD roundtrip test", - key = string.char(0x42) .. string.rep("\0", 31), - nonce = string.rep("\0", 12), + key = string_char(0x42) .. string_rep("\0", 31), + nonce = string_rep("\0", 12), aad = "", plaintext = "No additional data", }, { name = "Empty plaintext roundtrip test", - key = string.rep(string.char(0xff), 32), + key = string_rep(string_char(0xff), 32), nonce = bytes.from_hex("0102030405060708090a0b0c"), aad = "Only authenticating this data", plaintext = "", @@ -311,8 +318,8 @@ function chacha20_poly1305.selftest() -- Test 1: Basic encryption/decryption total = total + 1 - local key = string.rep(string.char(0x42), 32) - local nonce = string.rep("\0", 11) .. string.char(0x01) + local key = string_rep(string_char(0x42), 32) + local nonce = string_rep("\0", 11) .. string_char(0x01) local aad = "user@example.com|2024-01-01" local plaintext = "This is a secret message that needs both encryption and authentication." @@ -328,7 +335,7 @@ function chacha20_poly1305.selftest() -- Test 2: Authentication tag tampering detection total = total + 1 - local tampered = ciphertext_and_tag:sub(1, -2) .. string.char(255) + local tampered = ciphertext_and_tag:sub(1, -2) .. string_char(255) local tampered_result = chacha20_poly1305.decrypt(key, nonce, tampered, aad) if tampered_result == nil then @@ -352,7 +359,7 @@ function chacha20_poly1305.selftest() -- Test 4: Nonce uniqueness total = total + 1 - local nonce2 = string.rep("\0", 11) .. string.char(0x02) + local nonce2 = string_rep("\0", 11) .. string_char(0x02) local ciphertext2 = chacha20_poly1305.encrypt(key, nonce2, plaintext, aad) if ciphertext_and_tag ~= ciphertext2 then @@ -388,7 +395,7 @@ function chacha20_poly1305.selftest() -- Test 7: Ciphertext tampering detection total = total + 1 - local tampered_ct = string.char(255) .. ciphertext_and_tag:sub(2) + local tampered_ct = string_char(255) .. ciphertext_and_tag:sub(2) local tampered_ct_result = chacha20_poly1305.decrypt(key, nonce, tampered_ct, aad) if tampered_ct_result == nil then @@ -418,9 +425,9 @@ function chacha20_poly1305.benchmark() local key = bytes.from_hex("808182838485868788898a8b8c8d8e8f909192939495969798999a9b9c9d9e9f") local nonce = bytes.from_hex("070000004041424344454647") local aad = "Additional authenticated data" - local plaintext_64 = string.rep("a", 64) - local plaintext_1k = string.rep("a", 1024) - local plaintext_8k = string.rep("a", 8192) + local plaintext_64 = string_rep("a", 64) + local plaintext_1k = string_rep("a", 1024) + local plaintext_8k = string_rep("a", 8192) print("Authenticated Encryption Operations:") benchmark_op("encrypt_64_bytes", function() diff --git a/src/noiseprotocol/crypto/init.lua b/src/noiseprotocol/crypto/init.lua index ad69f8a..c7f1064 100644 --- a/src/noiseprotocol/crypto/init.lua +++ b/src/noiseprotocol/crypto/init.lua @@ -1,22 +1,38 @@ --- @module "noiseprotocol.crypto" +--- Cryptographic primitives for the Noise Protocol Framework +--- @class noiseprotocol.crypto local crypto = { -- Hash functions + + --- @type noiseprotocol.crypto.sha256 sha256 = require("noiseprotocol.crypto.sha256"), + --- @type noiseprotocol.crypto.sha512 sha512 = require("noiseprotocol.crypto.sha512"), + --- @type noiseprotocol.crypto.blake2 blake2 = require("noiseprotocol.crypto.blake2"), -- AEAD ciphers + + --- @type noiseprotocol.crypto.chacha20_poly1305 chacha20_poly1305 = require("noiseprotocol.crypto.chacha20_poly1305"), + --- @type noiseprotocol.crypto.aes_gcm aes_gcm = require("noiseprotocol.crypto.aes_gcm"), -- Stream ciphers + + --- @type noiseprotocol.crypto.chacha20 chacha20 = require("noiseprotocol.crypto.chacha20"), -- MAC + + --- @type noiseprotocol.crypto.poly1305 poly1305 = require("noiseprotocol.crypto.poly1305"), -- DH functions + + --- @type noiseprotocol.crypto.x25519 x25519 = require("noiseprotocol.crypto.x25519"), + --- @type noiseprotocol.crypto.x448 x448 = require("noiseprotocol.crypto.x448"), } diff --git a/src/noiseprotocol/crypto/poly1305.lua b/src/noiseprotocol/crypto/poly1305.lua index 6638cd3..71e68ba 100644 --- a/src/noiseprotocol/crypto/poly1305.lua +++ b/src/noiseprotocol/crypto/poly1305.lua @@ -1,13 +1,23 @@ --- @module "noiseprotocol.crypto.poly1305" --- Poly1305 Message Authentication Code (MAC) Implementation for portability. +--- @class noiseprotocol.crypto.poly1305 local poly1305 = {} -local bit32 = require("vendor.bitn").bit32 +local bit32 = require("bitn").bit32 local utils = require("noiseprotocol.utils") local bytes = utils.bytes local benchmark_op = utils.benchmark.benchmark_op +-- Local references for performance (avoid module table lookups in hot loops) +local bit32_band = bit32.band +local bit32_lshift = bit32.lshift +local floor = math.floor +local string_byte = string.byte +local string_char = string.char +local string_rep = string.rep +local table_concat = table.concat + -- Type definitions for better type checking --- 17-element limb array for 130-bit + overflow @@ -35,11 +45,11 @@ local function reduce_high_order_terms(prod, start_pos, end_pos) local reduction_multiplier = 5 -- Calculate target byte position for the reduction - local target_byte = 1 + math.floor(excess_bits / 8) + local target_byte = 1 + floor(excess_bits / 8) local bit_offset = excess_bits % 8 if bit_offset > 0 then - reduction_multiplier = bit32.lshift(reduction_multiplier, bit_offset) + reduction_multiplier = bit32_lshift(reduction_multiplier, bit_offset) end -- Add reduced value to target position @@ -64,7 +74,7 @@ local function propagate_carries(h) assert(h[i] ~= nil, "Limb array must have at least 17 non-nil elements") carry = carry + h[i] h[i] = carry % 256 - carry = math.floor(carry / 256) + carry = floor(carry / 256) end return carry end @@ -93,7 +103,7 @@ end --- @param h Limb17Array Limb array (modified in place) local function reduce_position_17(h) while h[17] >= 4 do - local high_bits = math.floor(h[17] / 4) + local high_bits = floor(h[17] / 4) h[17] = h[17] % 4 -- high_bits represents coefficient of 2^130, so multiply by 5 @@ -117,7 +127,7 @@ end --- Initialize a 33-element product array with zeros --- @return Limb33Array array Initialized array -local function create_product_array() +local function create_limb33_array() local arr = {} for i = 1, 33 do arr[i] = 0 @@ -126,6 +136,11 @@ local function create_product_array() return arr end +-- Pre-allocated arrays for authenticate() hot loop +local auth_c = create_limb17_array() -- Message block array (17 elements) +local auth_prod = create_limb33_array() -- Product array (33 elements) +local auth_g = create_limb17_array() -- Final reduction array (17 elements) + --- Initialize a 16-element key array --- @param source integer[] Source array to copy from --- @param offset integer Starting offset in source array @@ -158,7 +173,7 @@ function poly1305.authenticate(key, msg) --- @type integer[] local key_bytes = {} for i = 1, #key do - key_bytes[i] = string.byte(key, i) + key_bytes[i] = string_byte(key, i) end -- Extract and clamp r (first 16 bytes) per RFC 7539 @@ -166,13 +181,13 @@ function poly1305.authenticate(key, msg) -- Apply RFC 7539 clamping to ensure r has specific bit patterns -- This prevents certain classes of attacks and ensures key validity - r[4] = bit32.band(r[4], 15) -- Clear top 4 bits of 4th byte - r[5] = bit32.band(r[5], 252) -- Clear bottom 2 bits of 5th byte - r[8] = bit32.band(r[8], 15) -- Clear top 4 bits of 8th byte - r[9] = bit32.band(r[9], 252) -- Clear bottom 2 bits of 9th byte - r[12] = bit32.band(r[12], 15) -- Clear top 4 bits of 12th byte - r[13] = bit32.band(r[13], 252) -- Clear bottom 2 bits of 13th byte - r[16] = bit32.band(r[16], 15) -- Clear top 4 bits of 16th byte + r[4] = bit32_band(r[4], 15) -- Clear top 4 bits of 4th byte + r[5] = bit32_band(r[5], 252) -- Clear bottom 2 bits of 5th byte + r[8] = bit32_band(r[8], 15) -- Clear top 4 bits of 8th byte + r[9] = bit32_band(r[9], 252) -- Clear bottom 2 bits of 9th byte + r[12] = bit32_band(r[12], 15) -- Clear top 4 bits of 12th byte + r[13] = bit32_band(r[13], 252) -- Clear bottom 2 bits of 13th byte + r[16] = bit32_band(r[16], 15) -- Clear top 4 bits of 16th byte -- Extract s (second 16 bytes) - used for final addition local s = create_key_array(key_bytes, 17) @@ -183,12 +198,15 @@ function poly1305.authenticate(key, msg) local msglen = #msg local offset = 1 + -- Reuse pre-allocated arrays for hot loop + local c = auth_c + local prod = auth_prod + -- Process message in 16-byte blocks while msglen >= 16 do - -- Load current 16-byte block - local c = create_limb17_array() + -- Load current 16-byte block (reset and fill) for i = 1, 16 do - c[i] = string.byte(msg, offset + i - 1) + c[i] = string_byte(msg, offset + i - 1) end c[17] = 1 -- Add high bit (represents 2^128 for full blocks) @@ -197,13 +215,15 @@ function poly1305.authenticate(key, msg) for i = 1, 17 do carry = carry + h[i] + c[i] h[i] = carry % 256 - carry = math.floor(carry / 256) + carry = floor(carry / 256) end -- Multiply by r: h = (h * r) mod (2^130 - 5) - -- Step 1: Compute full precision product h * r - local prod = create_product_array() + -- Step 1: Compute full precision product h * r (reset prod first) + for i = 1, 33 do + prod[i] = 0 + end for i = 1, 17 do for j = 1, 16 do @@ -232,11 +252,14 @@ function poly1305.authenticate(key, msg) -- Process final partial block (if any) if msglen > 0 then - local c = create_limb17_array() + -- Reset c array for partial block + for i = 1, 17 do + c[i] = 0 + end -- Load partial block for i = 1, msglen do - c[i] = string.byte(msg, offset + i - 1) + c[i] = string_byte(msg, offset + i - 1) end c[msglen + 1] = 1 -- Add padding bit at end of message @@ -245,11 +268,13 @@ function poly1305.authenticate(key, msg) for i = 1, 17 do carry = carry + h[i] + c[i] h[i] = carry % 256 - carry = math.floor(carry / 256) + carry = floor(carry / 256) end - -- Multiply by r - local prod = create_product_array() + -- Multiply by r (reset prod first) + for i = 1, 33 do + prod[i] = 0 + end for i = 1, 17 do for j = 1, 16 do @@ -271,14 +296,15 @@ function poly1305.authenticate(key, msg) -- Final reduction: conditionally subtract (2^130 - 5) if h >= 2^130 - 5 -- This ensures the result is in canonical form - local g = create_limb17_array() + -- Reuse pre-allocated g array + local g = auth_g for i = 1, 17 do g[i] = h[i] end -- Test reduction by computing h + 5 g[1] = g[1] + 5 - local carry = math.floor(g[1] / 256) + local carry = floor(g[1] / 256) g[1] = g[1] % 256 for i = 2, 17 do @@ -287,7 +313,7 @@ function poly1305.authenticate(key, msg) end carry = carry + g[i] g[i] = carry % 256 - carry = math.floor(carry / 256) + carry = floor(carry / 256) end -- Use mask-based selection for constant-time operation @@ -303,37 +329,37 @@ function poly1305.authenticate(key, msg) carry = 0 for i = 1, 16 do local sum = h[i] + s[i] + carry - result_bytes[i] = string.char(sum % 256) - carry = math.floor(sum / 256) + result_bytes[i] = string_char(sum % 256) + carry = floor(sum / 256) end - return table.concat(result_bytes) + return table_concat(result_bytes) end --- Test vectors from RFC 8439, RFC 7539, and other reference implementations local test_vectors = { { name = "RFC 8439 Test Vector #1 (all zeros)", - key = string.rep("\0", 32), - message = string.rep("\0", 64), - expected = string.rep("\0", 16), + key = string_rep("\0", 32), + message = string_rep("\0", 64), + expected = string_rep("\0", 16), }, { name = "RFC 8439 Test Vector #2 (r=0, long message)", - key = string.rep("\0", 16) .. bytes.from_hex("36e5f6b5c5e06070f0efca96227a863e"), + key = string_rep("\0", 16) .. bytes.from_hex("36e5f6b5c5e06070f0efca96227a863e"), message = 'Any submission to the IETF intended by the Contributor for publication as all or part of an IETF Internet-Draft or RFC and any statement made within the context of an IETF activity is considered an "IETF Contribution". Such statements include oral statements in IETF sessions, as well as written and electronic communications made at any time or place, which are addressed to', expected = bytes.from_hex("36e5f6b5c5e06070f0efca96227a863e"), }, { name = "RFC 8439 Test Vector #3 (r!=0, s=0)", - key = bytes.from_hex("36e5f6b5c5e06070f0efca96227a863e") .. string.rep("\0", 16), + key = bytes.from_hex("36e5f6b5c5e06070f0efca96227a863e") .. string_rep("\0", 16), message = 'Any submission to the IETF intended by the Contributor for publication as all or part of an IETF Internet-Draft or RFC and any statement made within the context of an IETF activity is considered an "IETF Contribution". Such statements include oral statements in IETF sessions, as well as written and electronic communications made at any time or place, which are addressed to', expected = bytes.from_hex("f3477e7cd95417af89a6b8794c310cf0"), }, { name = "Wrap test vector (tests modular reduction edge case)", key = bytes.from_hex("0200000000000000000000000000000000000000000000000000000000000000"), - message = string.rep(string.char(255), 16), + message = string_rep(string_char(255), 16), expected = bytes.from_hex("03000000000000000000000000000000"), }, { @@ -375,11 +401,11 @@ function poly1305.selftest() local expected_hex = "" for j = 1, #result do - result_hex = result_hex .. string.format("%02x", string.byte(result, j)) + result_hex = result_hex .. string.format("%02x", string_byte(result, j)) end for j = 1, #test.expected do - expected_hex = expected_hex .. string.format("%02x", string.byte(test.expected, j)) + expected_hex = expected_hex .. string.format("%02x", string_byte(test.expected, j)) end if result == test.expected then @@ -405,8 +431,8 @@ function poly1305.selftest() -- Test 1: Different keys produce different tags total = total + 1 - local key1 = string.rep(string.char(0x42), 32) - local key2 = string.rep(string.char(0x43), 32) + local key1 = string_rep(string_char(0x42), 32) + local key2 = string_rep(string_char(0x43), 32) local message = "Test message for MAC verification" local tag1 = poly1305.authenticate(key1, message) @@ -447,7 +473,7 @@ function poly1305.selftest() -- Test 4: Large message handling (multi-block) total = total + 1 - local large_msg = string.rep("A", 256) -- 16 full blocks + local large_msg = string_rep("A", 256) -- 16 full blocks local large_tag = poly1305.authenticate(key1, large_msg) if #large_tag == 16 then @@ -459,7 +485,7 @@ function poly1305.selftest() -- Test 5: Partial block handling total = total + 1 - local partial_msg = string.rep("B", 33) -- 2 blocks + 1 byte + local partial_msg = string_rep("B", 33) -- 2 blocks + 1 byte local partial_tag = poly1305.authenticate(key1, partial_msg) if #partial_tag == 16 then @@ -514,9 +540,9 @@ end function poly1305.benchmark() -- Test data local key = bytes.from_hex("85d6be7857556d337f4452fe42d506a80103808afb0db2fd4abff6af4149f51b") - local message_64 = string.rep("a", 64) - local message_1k = string.rep("a", 1024) - local message_8k = string.rep("a", 8192) + local message_64 = string_rep("a", 64) + local message_1k = string_rep("a", 1024) + local message_8k = string_rep("a", 8192) print("MAC Operations:") benchmark_op("mac_64_bytes", function() diff --git a/src/noiseprotocol/crypto/sha256.lua b/src/noiseprotocol/crypto/sha256.lua index 8ae825c..938b52f 100644 --- a/src/noiseprotocol/crypto/sha256.lua +++ b/src/noiseprotocol/crypto/sha256.lua @@ -1,14 +1,29 @@ --- @module "noiseprotocol.crypto.sha256" --- Pure Lua SHA-256 Implementation for portability. +--- @class noiseprotocol.crypto.sha256 local sha256 = {} -local bit32 = require("vendor.bitn").bit32 +local bit32 = require("bitn").bit32 local openssl_wrapper = require("noiseprotocol.openssl_wrapper") local utils = require("noiseprotocol.utils") local bytes = utils.bytes local benchmark_op = utils.benchmark.benchmark_op +-- Local references for performance (avoid global/module table lookups in hot loops) +local bit32_bxor = bit32.bxor +local bit32_band = bit32.band +local bit32_bnot = bit32.bnot +local bit32_ror = bit32.ror +local bit32_rshift = bit32.rshift +local bit32_add = bit32.add +local bytes_be_bytes_to_u32 = bytes.be_bytes_to_u32 +local bytes_u32_to_be_bytes = bytes.u32_to_be_bytes +local string_char = string.char +local string_rep = string.rep +local string_byte = string.byte +local table_concat = table.concat + -- SHA-256 constants (first 32 bits of fractional parts of cube roots of first 64 primes) --- @type integer[64] local K = { @@ -93,59 +108,72 @@ local H0 = { 0x5be0cd19, } +--- Initialize a 64-element message schedule array with zeros +--- @return integer[] array Initialized array +local function create_message_schedule() + local arr = {} + for i = 1, 64 do + arr[i] = 0 + end + return arr +end + +-- Pre-allocated message schedule array for sha256_chunk() +local chunk_W = create_message_schedule() + --- SHA-256 core compression function --- @param chunk string 64-byte chunk --- @param H HashState Hash state (8 integers) local function sha256_chunk(chunk, H) - -- Prepare message schedule W (pre-allocate full array) - local W = {} + -- Reuse pre-allocated message schedule W + local W = chunk_W - -- First 16 words are the message chunk + -- First 16 words are the message chunk (use local reference) for i = 1, 16 do - W[i] = bytes.be_bytes_to_u32(chunk, (i - 1) * 4 + 1) + W[i] = bytes_be_bytes_to_u32(chunk, (i - 1) * 4 + 1) end -- Extend the first 16 words into the remaining 48 words for i = 17, 64 do local w15 = W[i - 15] local w2 = W[i - 2] - local s0 = bit32.bxor(bit32.ror(w15, 7), bit32.bxor(bit32.ror(w15, 18), bit32.rshift(w15, 3))) - local s1 = bit32.bxor(bit32.ror(w2, 17), bit32.bxor(bit32.ror(w2, 19), bit32.rshift(w2, 10))) - W[i] = bit32.add(bit32.add(bit32.add(W[i - 16], s0), W[i - 7]), s1) + local s0 = bit32_bxor(bit32_ror(w15, 7), bit32_bxor(bit32_ror(w15, 18), bit32_rshift(w15, 3))) + local s1 = bit32_bxor(bit32_ror(w2, 17), bit32_bxor(bit32_ror(w2, 19), bit32_rshift(w2, 10))) + W[i] = bit32_add(bit32_add(bit32_add(W[i - 16], s0), W[i - 7]), s1) end -- Initialize working variables local a, b, c, d, e, f, g, h = H[1], H[2], H[3], H[4], H[5], H[6], H[7], H[8] - -- Main loop (optimized with local variables) + -- Main loop (optimized with local references) for i = 1, 64 do - local prime = K[i] - local S1 = bit32.bxor(bit32.ror(e, 6), bit32.bxor(bit32.ror(e, 11), bit32.ror(e, 25))) - local ch = bit32.bxor(bit32.band(e, f), bit32.band(bit32.bnot(e), g)) - local temp1 = bit32.add(bit32.add(bit32.add(bit32.add(h, S1), ch), prime), W[i]) - local S0 = bit32.bxor(bit32.ror(a, 2), bit32.bxor(bit32.ror(a, 13), bit32.ror(a, 22))) - local maj = bit32.bxor(bit32.band(a, b), bit32.bxor(bit32.band(a, c), bit32.band(b, c))) - local temp2 = bit32.add(S0, maj) + local ki = K[i] + local S1 = bit32_bxor(bit32_ror(e, 6), bit32_bxor(bit32_ror(e, 11), bit32_ror(e, 25))) + local ch = bit32_bxor(bit32_band(e, f), bit32_band(bit32_bnot(e), g)) + local temp1 = bit32_add(bit32_add(bit32_add(bit32_add(h, S1), ch), ki), W[i]) + local S0 = bit32_bxor(bit32_ror(a, 2), bit32_bxor(bit32_ror(a, 13), bit32_ror(a, 22))) + local maj = bit32_bxor(bit32_band(a, b), bit32_bxor(bit32_band(a, c), bit32_band(b, c))) + local temp2 = bit32_add(S0, maj) h = g g = f f = e - e = bit32.add(d, temp1) + e = bit32_add(d, temp1) d = c c = b b = a - a = bit32.add(temp1, temp2) + a = bit32_add(temp1, temp2) end -- Add compressed chunk to current hash value - H[1] = bit32.add(H[1], a) - H[2] = bit32.add(H[2], b) - H[3] = bit32.add(H[3], c) - H[4] = bit32.add(H[4], d) - H[5] = bit32.add(H[5], e) - H[6] = bit32.add(H[6], f) - H[7] = bit32.add(H[7], g) - H[8] = bit32.add(H[8], h) + H[1] = bit32_add(H[1], a) + H[2] = bit32_add(H[2], b) + H[3] = bit32_add(H[3], c) + H[4] = bit32_add(H[4], d) + H[5] = bit32_add(H[5], e) + H[6] = bit32_add(H[6], f) + H[7] = bit32_add(H[7], g) + H[8] = bit32_add(H[8], h) end -- ============================================================================ @@ -172,18 +200,18 @@ function sha256.sha256(data) local msg_len_bits = msg_len * 8 -- Append '1' bit (plus zero padding to make it a byte) - data = data .. string.char(0x80) + data = data .. string_char(0x80) -- Append zeros to make message length ≡ 448 (mod 512) bits = 56 (mod 64) bytes -- Current length is msg_len + 1 (for the 0x80 byte) local current_len = msg_len + 1 local target_len = 56 -- We want to reach 56 bytes before adding the 8-byte length local padding_len = (target_len - current_len) % 64 - data = data .. string.rep("\0", padding_len) + data = data .. string_rep("\0", padding_len) -- Append original length as 64-bit big-endian integer -- For simplicity, we only support messages < 2^32 bits - data = data .. string.rep("\0", 4) .. bytes.u32_to_be_bytes(msg_len_bits) + data = data .. string_rep("\0", 4) .. bytes_u32_to_be_bytes(msg_len_bits) -- Process message in 64-byte chunks for i = 1, #data, 64 do @@ -196,10 +224,10 @@ function sha256.sha256(data) -- Produce final hash value as binary string (optimized with table) local result_bytes = {} for i = 1, 8 do - result_bytes[i] = bytes.u32_to_be_bytes(H[i]) + result_bytes[i] = bytes_u32_to_be_bytes(H[i]) end - return table.concat(result_bytes) + return table_concat(result_bytes) end --- Compute SHA-256 hash and return as hex string @@ -231,19 +259,19 @@ function sha256.hmac_sha256(key, data) -- Keys shorter than blocksize are right-padded with zeros if #key < block_size then - key = key .. string.rep("\0", block_size - #key) + key = key .. string_rep("\0", block_size - #key) end - -- Compute inner and outer padding (optimized with table) + -- Compute inner and outer padding (optimized with local references) local ipad_bytes = {} local opad_bytes = {} for i = 1, block_size do - local byte = string.byte(key, i) - ipad_bytes[i] = string.char(bit32.bxor(byte, 0x36)) - opad_bytes[i] = string.char(bit32.bxor(byte, 0x5C)) + local byte = string_byte(key, i) + ipad_bytes[i] = string_char(bit32_bxor(byte, 0x36)) + opad_bytes[i] = string_char(bit32_bxor(byte, 0x5C)) end - local ipad = table.concat(ipad_bytes) - local opad = table.concat(opad_bytes) + local ipad = table_concat(ipad_bytes) + local opad = table_concat(opad_bytes) -- Compute HMAC = H(opad || H(ipad || data)) local inner_hash = sha256.sha256(ipad .. data) @@ -304,7 +332,7 @@ local test_vectors = { if os.getenv("INCLUDE_SLOW_TESTS") == "1" then table.insert(test_vectors, { name = "Million 'a' characters", - input = string.rep("a", 1000000), + input = string_rep("a", 1000000), expected = "cdc76e5c9914fb9281a1c7e284d73e67f1809a48a497200e046d39ccc7112cd0", }) end @@ -313,7 +341,7 @@ end local hmac_test_vectors = { { name = "HMAC Test Case 1", - key = string.rep(string.char(0x0b), 20), + key = string_rep(string_char(0x0b), 20), data = "Hi There", expected = "b0344c61d8db38535ca8afceaf0bf12b881dc200c9833da726e9376c2e32cff7", }, @@ -325,8 +353,8 @@ local hmac_test_vectors = { }, { name = "HMAC Test Case 3", - key = string.rep(string.char(0xaa), 20), - data = string.rep(string.char(0xdd), 50), + key = string_rep(string_char(0xaa), 20), + data = string_rep(string_char(0xdd), 50), expected = "773ea91e36800e46854db8ebd09181a72959098b3ef8c122d9635514ced565fe", }, } @@ -446,9 +474,9 @@ end --- including hash computation and HMAC for various message sizes. function sha256.benchmark() -- Test data - local message_64 = string.rep("a", 64) - local message_1k = string.rep("a", 1024) - local message_8k = string.rep("a", 8192) + local message_64 = string_rep("a", 64) + local message_1k = string_rep("a", 1024) + local message_8k = string_rep("a", 8192) local hmac_key = "benchmark_key" print("Hash Operations:") diff --git a/src/noiseprotocol/crypto/sha512.lua b/src/noiseprotocol/crypto/sha512.lua index d4171ba..0cad989 100644 --- a/src/noiseprotocol/crypto/sha512.lua +++ b/src/noiseprotocol/crypto/sha512.lua @@ -1,8 +1,9 @@ --- @module "noiseprotocol.crypto.sha512" --- Pure Lua SHA-512 Implementation for portability. +--- @class noiseprotocol.crypto.sha512 local sha512 = {} -local bitn = require("vendor.bitn") +local bitn = require("bitn") local bit32 = bitn.bit32 local bit64 = bitn.bit64 @@ -11,6 +12,21 @@ local utils = require("noiseprotocol.utils") local bytes = utils.bytes local benchmark_op = utils.benchmark.benchmark_op +-- Local references for performance (avoid module table lookups in hot loops) +local bit64_add = bit64.add +local bit64_xor = bit64.xor +local bit64_band = bit64.band +local bit64_bnot = bit64.bnot +local bit64_ror = bit64.ror +local bit64_shr = bit64.shr +local bit64_new = bit64.new +local bit32_bxor = bit32.bxor +local string_char = string.char +local string_rep = string.rep +local string_byte = string.byte +local table_concat = table.concat +local floor = math.floor + -- SHA-512 uses 64-bit words, but Lua numbers are limited to 2^53-1 -- We'll work with 32-bit high/low pairs for 64-bit arithmetic @@ -114,32 +130,45 @@ local H0 = { { 0x5be0cd19, 0x137e2179 }, } +--- Initialize an 80-element message schedule array with zeros (64-bit values) +--- @return Int64HighLow[] array Initialized array +local function create_message_schedule_64() + local arr = {} + for i = 1, 80 do + arr[i] = bit64_new(0, 0) + end + return arr +end + +-- Pre-allocated message schedule array for sha512_chunk() +local chunk_W = create_message_schedule_64() + --- SHA-512 Sigma0 function --- @param x Int64HighLow {high, low} input --- @return Int64HighLow {high, low} result local function Sigma0(x) - return bit64.xor(bit64.xor(bit64.ror(x, 28), bit64.ror(x, 34)), bit64.ror(x, 39)) + return bit64_xor(bit64_xor(bit64_ror(x, 28), bit64_ror(x, 34)), bit64_ror(x, 39)) end --- SHA-512 Sigma1 function --- @param x Int64HighLow {high, low} input --- @return Int64HighLow {high, low} result local function Sigma1(x) - return bit64.xor(bit64.xor(bit64.ror(x, 14), bit64.ror(x, 18)), bit64.ror(x, 41)) + return bit64_xor(bit64_xor(bit64_ror(x, 14), bit64_ror(x, 18)), bit64_ror(x, 41)) end --- SHA-512 sigma0 function --- @param x Int64HighLow {high, low} input --- @return Int64HighLow {high, low} result local function sigma0(x) - return bit64.xor(bit64.xor(bit64.ror(x, 1), bit64.ror(x, 8)), bit64.shr(x, 7)) + return bit64_xor(bit64_xor(bit64_ror(x, 1), bit64_ror(x, 8)), bit64_shr(x, 7)) end --- SHA-512 sigma1 function --- @param x Int64HighLow {high, low} input --- @return Int64HighLow {high, low} result local function sigma1(x) - return bit64.xor(bit64.xor(bit64.ror(x, 19), bit64.ror(x, 61)), bit64.shr(x, 6)) + return bit64_xor(bit64_xor(bit64_ror(x, 19), bit64_ror(x, 61)), bit64_shr(x, 6)) end --- SHA-512 Ch function @@ -148,7 +177,7 @@ end --- @param z Int64HighLow {high, low} input --- @return Int64HighLow {high, low} result local function Ch(x, y, z) - return bit64.xor(bit64.band(x, y), bit64.band(bit64.bnot(x), z)) + return bit64_xor(bit64_band(x, y), bit64_band(bit64_bnot(x), z)) end --- SHA-512 Maj function @@ -157,19 +186,20 @@ end --- @param z Int64HighLow {high, low} input --- @return Int64HighLow {high, low} result local function Maj(x, y, z) - return bit64.xor(bit64.xor(bit64.band(x, y), bit64.band(x, z)), bit64.band(y, z)) + return bit64_xor(bit64_xor(bit64_band(x, y), bit64_band(x, z)), bit64_band(y, z)) end --- SHA-512 core compression function --- @param chunk string 128-byte chunk --- @param H HashState64 Hash state (8 64-bit values) local function sha512_chunk(chunk, H) - -- Prepare message schedule W (pre-allocate full array) - local W = {} + -- Reuse pre-allocated message schedule W + local W = chunk_W -- First 16 words are the message chunk for i = 1, 16 do - W[i] = bytes.be_bytes_to_u64(chunk, (i - 1) * 8 + 1) + local val = bytes.be_bytes_to_u64(chunk, (i - 1) * 8 + 1) + W[i][1], W[i][2] = val[1], val[2] end -- Extend the first 16 words into the remaining 64 words @@ -178,7 +208,8 @@ local function sha512_chunk(chunk, H) local w2 = W[i - 2] local s0 = sigma0(w15) local s1 = sigma1(w2) - W[i] = bit64.add(bit64.add(bit64.add(W[i - 16], s0), W[i - 7]), s1) + local result = bit64_add(bit64_add(bit64_add(W[i - 16], s0), W[i - 7]), s1) + W[i][1], W[i][2] = result[1], result[2] end -- Initialize working variables @@ -189,30 +220,30 @@ local function sha512_chunk(chunk, H) local prime = K[i] local S1 = Sigma1(e) local ch = Ch(e, f, g) - local temp1 = bit64.add(bit64.add(bit64.add(bit64.add(h, S1), ch), prime), W[i]) + local temp1 = bit64_add(bit64_add(bit64_add(bit64_add(h, S1), ch), prime), W[i]) local S0 = Sigma0(a) local maj = Maj(a, b, c) - local temp2 = bit64.add(S0, maj) + local temp2 = bit64_add(S0, maj) h = g g = f f = e - e = bit64.add(d, temp1) + e = bit64_add(d, temp1) d = c c = b b = a - a = bit64.add(temp1, temp2) + a = bit64_add(temp1, temp2) end -- Add compressed chunk to current hash value - H[1] = bit64.add(H[1], a) - H[2] = bit64.add(H[2], b) - H[3] = bit64.add(H[3], c) - H[4] = bit64.add(H[4], d) - H[5] = bit64.add(H[5], e) - H[6] = bit64.add(H[6], f) - H[7] = bit64.add(H[7], g) - H[8] = bit64.add(H[8], h) + H[1] = bit64_add(H[1], a) + H[2] = bit64_add(H[2], b) + H[3] = bit64_add(H[3], c) + H[4] = bit64_add(H[4], d) + H[5] = bit64_add(H[5], e) + H[6] = bit64_add(H[6], f) + H[7] = bit64_add(H[7], g) + H[8] = bit64_add(H[8], h) end --- Compute SHA-512 hash of input data @@ -238,19 +269,19 @@ function sha512.sha512(data) local msg_len_bits = msg_len * 8 -- Append '1' bit (plus zero padding to make it a byte) - data = data .. string.char(0x80) + data = data .. string_char(0x80) -- Append zeros to make message length ≡ 896 (mod 1024) bits = 112 (mod 128) bytes local current_len = msg_len + 1 local target_len = 112 -- We want to reach 112 bytes before adding the 16-byte length local padding_len = (target_len - current_len) % 128 - data = data .. string.rep("\0", padding_len) + data = data .. string_rep("\0", padding_len) -- Append original length as 128-bit big-endian integer -- For simplicity, we only support messages < 2^64 bits - data = data .. string.rep("\0", 8) -- High 64 bits (always 0) + data = data .. string_rep("\0", 8) -- High 64 bits (always 0) -- Low 64 bits of length - local len_high = math.floor(msg_len_bits / 0x100000000) + local len_high = floor(msg_len_bits / 0x100000000) local len_low = msg_len_bits % 0x100000000 data = data .. bytes.u64_to_be_bytes({ len_high, len_low }) @@ -268,7 +299,7 @@ function sha512.sha512(data) result_bytes[i] = bytes.u64_to_be_bytes(H[i]) end - return table.concat(result_bytes) + return table_concat(result_bytes) end --- Compute SHA-512 hash and return as hex string @@ -299,19 +330,19 @@ function sha512.hmac_sha512(key, data) -- Keys shorter than blocksize are right-padded with zeros if #key < block_size then - key = key .. string.rep("\0", block_size - #key) + key = key .. string_rep("\0", block_size - #key) end -- Compute inner and outer padding (optimized with table) local ipad_bytes = {} local opad_bytes = {} for i = 1, block_size do - local byte = string.byte(key, i) - ipad_bytes[i] = string.char(bit32.bxor(byte, 0x36)) - opad_bytes[i] = string.char(bit32.bxor(byte, 0x5C)) + local byte = string_byte(key, i) + ipad_bytes[i] = string_char(bit32_bxor(byte, 0x36)) + opad_bytes[i] = string_char(bit32_bxor(byte, 0x5C)) end - local ipad = table.concat(ipad_bytes) - local opad = table.concat(opad_bytes) + local ipad = table_concat(ipad_bytes) + local opad = table_concat(opad_bytes) -- Compute HMAC = H(opad || H(ipad || data)) local inner_hash = sha512.sha512(ipad .. data) @@ -352,7 +383,7 @@ local test_vectors = { if os.getenv("INCLUDE_SLOW_TESTS") == "1" then table.insert(test_vectors, { name = "RFC 4634 Test 5 - One million 'a' characters", - input = string.rep("a", 1000000), + input = string_rep("a", 1000000), expected = "e718483d0ce769644e2e42c7bc15b4638e1f98b13b2044285632a803afa973ebde0ff244877ea60a4cb0432ce577c31beb009c5c2c49aa2e4eadb217ad8cc09b", }) end @@ -361,7 +392,7 @@ end local hmac_test_vectors = { { name = "RFC 4231 Test Case 1", - key = string.rep(string.char(0x0b), 20), + key = string_rep(string_char(0x0b), 20), data = "Hi There", expected = "87aa7cdea5ef619d4ff0b4241a1d6cb02379f4e2ce4ec2787ad0b30545e17cdedaa833b7d6b8a702038b274eaea3f4e4be9d914eeb61f1702e696c203a126854", }, @@ -373,14 +404,14 @@ local hmac_test_vectors = { }, { name = "RFC 4231 Test Case 3", - key = string.rep(string.char(0xaa), 20), - data = string.rep(string.char(0xdd), 50), + key = string_rep(string_char(0xaa), 20), + data = string_rep(string_char(0xdd), 50), expected = "fa73b0089d56a284efb0f0756c890be9b1b5dbdd8ee81a3655f83e33b2279d39bf3e848279a722c806b485a47e67c807b946a337bee8942674278859e13292fb", }, { name = "RFC 4231 Test Case 4", key = bytes.from_hex("0102030405060708090a0b0c0d0e0f10111213141516171819"), - data = string.rep(string.char(0xcd), 50), + data = string_rep(string_char(0xcd), 50), expected = "b0ba465637458c6990e5a8c5f61d4af7e576d97ff94b872de76f8050361ee3dba91ca5c11aa25eb4d679275cc5788063a5f19741120c4f2de2adebeb10a298dd", }, } @@ -500,9 +531,9 @@ end --- including hash computation and HMAC for various message sizes. function sha512.benchmark() -- Test data - local message_64 = string.rep("a", 64) - local message_1k = string.rep("a", 1024) - local message_8k = string.rep("a", 8192) + local message_64 = string_rep("a", 64) + local message_1k = string_rep("a", 1024) + local message_8k = string_rep("a", 8192) local hmac_key = "benchmark_key" print("Hash Operations:") diff --git a/src/noiseprotocol/crypto/x25519.lua b/src/noiseprotocol/crypto/x25519.lua index 70a619f..be1073b 100644 --- a/src/noiseprotocol/crypto/x25519.lua +++ b/src/noiseprotocol/crypto/x25519.lua @@ -1,8 +1,9 @@ --- @module "noiseprotocol.crypto.x25519" --- X25519 Curve25519 Elliptic Curve Diffie-Hellman Implementation for portability. +--- @class noiseprotocol.crypto.x25519 local x25519 = {} -local bit32 = require("vendor.bitn").bit32 +local bit32 = require("bitn").bit32 local utils = require("noiseprotocol.utils") local bytes = utils.bytes @@ -12,18 +13,84 @@ local benchmark_op = utils.benchmark.benchmark_op -- CURVE25519 FIELD ARITHMETIC -- ============================================================================ +-- Local references for performance (avoid global table lookups in hot loops) +local floor = math.floor +local string_byte = string.byte +local string_char = string.char +local string_rep = string.rep +local table_concat = table.concat + +--- @alias FieldElement integer[] 16-element array (indices 1-16) representing a field element +--- @alias ProductArray integer[] 31-element array (indices 1-31) for multiplication products +--- @alias ScalarArray integer[] 32-element array (indices 1-32) for scalar bytes + +--- Initialize a 16-element field element with zeros +--- @return FieldElement fe Initialized field element +local function create_field_element() + local arr = {} + for i = 1, 16 do + arr[i] = 0 + end + return arr +end + +--- Initialize a 31-element product array with zeros +--- @return ProductArray arr Initialized array +local function create_product_array() + local arr = {} + for i = 1, 31 do + arr[i] = 0 + end + return arr +end + +--- Initialize a 32-element scalar array with zeros +--- @return ScalarArray arr Initialized array +local function create_scalar_array() + local arr = {} + for i = 1, 32 do + arr[i] = 0 + end + return arr +end + +-- Pre-allocated constant for Montgomery ladder (a24 = 121665 = 0xdb41 + 1*0x10000) +-- This is (A-2)/4 where A=486662 for Curve25519 +local A24 = { 0xdb41, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 } + +-- Pre-allocated product array for mul() to avoid repeated allocation +local mul_prod = create_product_array() + +-- Pre-allocated arrays for pack() to avoid repeated allocation +local pack_t = create_field_element() +local pack_m = create_field_element() + +-- Pre-allocated arrays for inv() to avoid repeated allocation +local inv_c = create_field_element() + +-- Pre-allocated arrays for scalarmult() Montgomery ladder +-- These are the most critical - 8 arrays created per DH operation +local sm_a = create_field_element() +local sm_b = create_field_element() +local sm_c = create_field_element() +local sm_d = create_field_element() +local sm_e = create_field_element() +local sm_f = create_field_element() +local sm_x = create_field_element() +local sm_clam = create_scalar_array() + --- Carry operation for 64-bit arithmetic --- @param out integer[] Array to perform carry on local function carry(out) - for i = 0, 15 do - out[i] = out[i] + 0x10000 - local c = out[i] / 0x10000 - (out[i] / 0x10000) % 1 - if i < 15 then + for i = 1, 16 do + local v = out[i] + 0x10000 + local c = floor(v * 0.0000152587890625) -- 1/0x10000 = 0.0000152587890625 + if i < 16 then out[i + 1] = out[i + 1] + c - 1 else - out[0] = out[0] + 38 * (c - 1) + out[1] = out[1] + 38 * (c - 1) end - out[i] = out[i] - c * 0x10000 + out[i] = v - c * 0x10000 end end @@ -32,7 +99,7 @@ end --- @param b integer[] Second array --- @param bit integer Bit value (0 or 1) local function swap(a, b, bit) - for i = 0, 15 do + for i = 1, 16 do a[i], b[i] = a[i] * ((bit - 1) % 2) + b[i] * bit, b[i] * ((bit - 1) % 2) + a[i] * bit end end @@ -41,39 +108,58 @@ end --- @param out integer[] Output limb array --- @param a integer[] Input byte array local function unpack(out, a) - for i = 0, 15 do - out[i] = a[2 * i] + a[2 * i + 1] * 0x100 + for i = 1, 16 do + out[i] = a[2 * i - 1] + a[2 * i] * 0x100 end - out[15] = out[15] % 0x8000 + out[16] = out[16] % 0x8000 end +-- Pre-allocated prime constant for pack() +local PRIME = { + 0xffed, + 0xffff, + 0xffff, + 0xffff, + 0xffff, + 0xffff, + 0xffff, + 0xffff, + 0xffff, + 0xffff, + 0xffff, + 0xffff, + 0xffff, + 0xffff, + 0xffff, + 0x7fff, +} + --- Pack limb array to byte array with modular reduction --- @param out integer[] Output byte array --- @param a integer[] Input limb array local function pack(out, a) - local t, m = {}, {} - for i = 0, 15 do + -- Reuse pre-allocated arrays + local t, m = pack_t, pack_m + for i = 1, 16 do t[i] = a[i] end carry(t) carry(t) carry(t) - local prime = { [0] = 0xffed, [15] = 0x7fff } - for i = 1, 14 do - prime[i] = 0xffff - end - for _ = 0, 1 do - m[0] = t[0] - prime[0] - for i = 1, 15 do - m[i] = t[i] - prime[i] - ((m[i - 1] / 0x10000 - (m[i - 1] / 0x10000) % 1) % 2) - m[i - 1] = (m[i - 1] + 0x10000) % 0x10000 + for _ = 1, 2 do + m[1] = t[1] - PRIME[1] + for i = 2, 16 do + local prev = m[i - 1] + m[i] = t[i] - PRIME[i] - (floor(prev * 0.0000152587890625) % 2) + m[i - 1] = (prev + 0x10000) % 0x10000 end - local c = (m[15] / 0x10000 - (m[15] / 0x10000) % 1) % 2 + local c = floor(m[16] * 0.0000152587890625) % 2 swap(t, m, 1 - c) end - for i = 0, 15 do - out[2 * i] = t[i] % 0x100 - out[2 * i + 1] = t[i] / 0x100 - (t[i] / 0x100) % 1 + for i = 1, 16 do + local ti = t[i] + out[2 * i - 1] = ti % 0x100 + out[2 * i] = floor(ti * 0.00390625) -- 1/256 end end @@ -82,7 +168,7 @@ end --- @param a integer[] First input array --- @param b integer[] Second input array local function add(out, a, b) - for i = 0, 15 do + for i = 1, 16 do out[i] = a[i] + b[i] end end @@ -92,7 +178,7 @@ end --- @param a integer[] First input array --- @param b integer[] Second input array local function sub(out, a, b) - for i = 0, 15 do + for i = 1, 16 do out[i] = a[i] - b[i] end end @@ -102,19 +188,23 @@ end --- @param a integer[] First input array --- @param b integer[] Second input array local function mul(out, a, b) - local prod = {} - for i = 0, 31 do + -- Reuse pre-allocated array and clear it + local prod = mul_prod + for i = 1, 31 do prod[i] = 0 end - for i = 0, 15 do - for j = 0, 15 do - prod[i + j] = prod[i + j] + a[i] * b[j] + -- Schoolbook multiplication + for i = 1, 16 do + local ai = a[i] + for j = 1, 16 do + prod[i + j - 1] = prod[i + j - 1] + ai * b[j] end end - for i = 0, 14 do + -- Reduce mod 2^255-19 (multiply high limbs by 38 and add to low) + for i = 1, 15 do prod[i] = prod[i] + 38 * prod[i + 16] end - for i = 0, 15 do + for i = 1, 16 do out[i] = prod[i] end carry(out) @@ -125,8 +215,9 @@ end --- @param out integer[] Output array --- @param a integer[] Input array local function inv(out, a) - local c = {} - for i = 0, 15 do + -- Reuse pre-allocated array + local c = inv_c + for i = 1, 16 do c[i] = a[i] end for i = 253, 0, -1 do @@ -135,7 +226,7 @@ local function inv(out, a) mul(c, c, a) end end - for i = 0, 15 do + for i = 1, 16 do out[i] = c[i] end end @@ -145,19 +236,21 @@ end --- @param scalar integer[] Input scalar --- @param point integer[] Input point local function scalarmult(out, scalar, point) - local a, b, c, d, e, f, x, clam = {}, {}, {}, {}, {}, {}, {}, {} + -- Reuse pre-allocated arrays for Montgomery ladder state + local a, b, c, d, e, f, x, clam = sm_a, sm_b, sm_c, sm_d, sm_e, sm_f, sm_x, sm_clam unpack(x, point) - for i = 0, 15 do + for i = 1, 16 do a[i], b[i], c[i], d[i] = 0, x[i], 0, 0 end - a[0], d[0] = 1, 1 - for i = 0, 30 do + a[1], d[1] = 1, 1 + for i = 1, 31 do clam[i] = scalar[i] end - clam[0] = clam[0] - (clam[0] % 8) - clam[31] = scalar[31] % 64 + 64 + clam[1] = clam[1] - (clam[1] % 8) + clam[32] = scalar[32] % 64 + 64 for i = 254, 0, -1 do - local byte_idx = math.floor(i / 8) + -- Optimized bit extraction + local byte_idx = floor(i * 0.125) + 1 -- i / 8 + 1 local bit_idx = i % 8 local bit = bit32.band(bit32.rshift(clam[byte_idx], bit_idx), 1) swap(a, b, bit) @@ -174,7 +267,7 @@ local function scalarmult(out, scalar, point) sub(a, a, c) mul(b, a, a) sub(c, d, f) - mul(a, c, { [0] = 0xdb41, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }) + mul(a, c, A24) -- Use pre-allocated constant add(a, a, d) mul(c, c, a) mul(a, d, f) @@ -194,7 +287,7 @@ end local function string_to_bytes(s) local b = {} for i = 1, #s do - b[i - 1] = string.byte(s, i) + b[i] = string_byte(s, i) end return b end @@ -205,10 +298,10 @@ end --- @return string result Output string local function bytes_to_string(b, len) local result_bytes = {} - for i = 0, len - 1 do - result_bytes[i + 1] = string.char(b[i] or 0) + for i = 1, len do + result_bytes[i] = string_char(b[i] or 0) end - return table.concat(result_bytes) + return table_concat(result_bytes) end -- ============================================================================ @@ -225,9 +318,9 @@ function x25519.generate_private_key() local key_bytes = {} for i = 1, 32 do - key_bytes[i] = string.char(math.random(0, 255)) + key_bytes[i] = string_char(math.random(0, 255)) end - return table.concat(key_bytes) + return table_concat(key_bytes) end --- Derive public key from private key @@ -238,8 +331,8 @@ function x25519.derive_public_key(private_key) local sk = string_to_bytes(private_key) local pk = {} - local base = { [0] = 9 } - for i = 1, 31 do + local base = { 9 } + for i = 2, 32 do base[i] = 0 end @@ -280,15 +373,15 @@ end local test_vectors = { { name = "RFC 7748 Test Vector 1", - scalar = utils.bytes.from_hex("a546e36bf0527c9d3b16154b82465edd62144c0ac1fc5a18506a2244ba449ac4"), - u_coord = utils.bytes.from_hex("e6db6867583030db3594c1a424b15f7c726624ec26b3353b10a903a6d0ab1c4c"), - expected = utils.bytes.from_hex("c3da55379de9c6908e94ea4df28d084f32eccf03491c71f754b4075577a28552"), + scalar = bytes.from_hex("a546e36bf0527c9d3b16154b82465edd62144c0ac1fc5a18506a2244ba449ac4"), + u_coord = bytes.from_hex("e6db6867583030db3594c1a424b15f7c726624ec26b3353b10a903a6d0ab1c4c"), + expected = bytes.from_hex("c3da55379de9c6908e94ea4df28d084f32eccf03491c71f754b4075577a28552"), }, { name = "RFC 7748 Test Vector 2", - scalar = utils.bytes.from_hex("4b66e9d4d1b4673c5ad22691957d6af5c11b6421e0ea01d42ca4169e7918ba0d"), - u_coord = utils.bytes.from_hex("e5210f12786811d3f4b7959d0538ae2c31dbe7106fc03c3efc4cd549c715a493"), - expected = utils.bytes.from_hex("95cbde9476e8907d7aade45cb4b873f88b595a68799fa152e6f8f7647aac7957"), + scalar = bytes.from_hex("4b66e9d4d1b4673c5ad22691957d6af5c11b6421e0ea01d42ca4169e7918ba0d"), + u_coord = bytes.from_hex("e5210f12786811d3f4b7959d0538ae2c31dbe7106fc03c3efc4cd549c715a493"), + expected = bytes.from_hex("95cbde9476e8907d7aade45cb4b873f88b595a68799fa152e6f8f7647aac7957"), }, } @@ -320,10 +413,10 @@ function x25519.selftest() local result_hex = "" local expected_hex = "" for j = 1, #result do - result_hex = result_hex .. string.format("%02x", string.byte(result, j)) + result_hex = result_hex .. string.format("%02x", string_byte(result, j)) end for j = 1, #test.expected do - expected_hex = expected_hex .. string.format("%02x", string.byte(test.expected, j)) + expected_hex = expected_hex .. string.format("%02x", string_byte(test.expected, j)) end print(" Expected: " .. expected_hex) print(" Got: " .. result_hex) @@ -418,7 +511,7 @@ function x25519.selftest() -- Test 5: Edge case - all zero input (should not fail) total = total + 1 success, err = pcall(function() - local zero_key = string.rep("\0", 32) + local zero_key = string_rep("\0", 32) local priv, _pub = x25519.generate_keypair() -- This should not crash, though result may be predictable diff --git a/src/noiseprotocol/crypto/x448.lua b/src/noiseprotocol/crypto/x448.lua index c2b0b54..6f343f9 100644 --- a/src/noiseprotocol/crypto/x448.lua +++ b/src/noiseprotocol/crypto/x448.lua @@ -9,9 +9,10 @@ --- - Field arithmetic modulo p = 2^448 - 2^224 - 1 --- - Scalar multiplication on Curve448 --- - Key generation and Diffie-Hellman operations +--- @class noiseprotocol.crypto.x448 local x448 = {} -local bitn = require("vendor.bitn") +local bitn = require("bitn") local band = bitn.bit32.band local bor = bitn.bit32.bor local bxor = bitn.bit32.bxor @@ -23,6 +24,8 @@ local benchmark_op = utils.benchmark.benchmark_op local floor = math.floor local char = string.char local byte = string.byte +local string_rep = string.rep +local table_concat = table.concat -- Constants for X448 implementation -- Field prime p = 2^448 - 2^224 - 1 (Goldilocks prime) @@ -331,7 +334,7 @@ local function cswap(swap, a, b) end --- Convert bytes to field element (little-endian) ---- @param bytes string 56-byte string +--- @param b string 56-byte string --- @return table fe Field element local function fe_frombytes(b) local r = fe_zero() @@ -356,7 +359,7 @@ local function fe_tobytes(a) b[i] = char(band(t[i] or 0, 0xFF)) end - return table.concat(b) + return table_concat(b) end --- X448 scalar multiplication @@ -456,7 +459,7 @@ function x448.derive_public_key(private_key) assert(#private_key == 56, "Private key must be exactly 56 bytes") -- Base point for X448 (u = 5) - local base = char(5) .. string.rep(char(0), 55) + local base = char(5) .. string_rep(char(0), 55) return x448_scalarmult(private_key, base) end @@ -627,7 +630,7 @@ function x448.selftest() ok = pcall(function() -- Test with all-zero public key local private_key = x448.generate_private_key() - local zero_public = string.rep(char(0), 56) + local zero_public = string_rep(char(0), 56) local shared = x448.diffie_hellman(private_key, zero_public) assert(#shared == 56, "Should handle zero public key") end) diff --git a/src/noiseprotocol/init.lua b/src/noiseprotocol/init.lua index 52c69c9..72fef3d 100644 --- a/src/noiseprotocol/init.lua +++ b/src/noiseprotocol/init.lua @@ -18,6 +18,7 @@ --- static_key = my_static_key --- }) --- ... +--- @class noiseprotocol local noiseprotocol = {} local crypto = require("noiseprotocol.crypto") @@ -46,6 +47,15 @@ end -- PROTOCOL NAME PARSING -- ============================================================================ +--- Parsed protocol name components +--- @class ParsedProtocolName +--- @field pattern string Base handshake pattern (e.g., "XX", "IK", "NN") +--- @field modifiers string[] List of modifier strings (e.g., {"psk0", "psk2"}) +--- @field dh string Diffie-Hellman function name (e.g., "25519", "448") +--- @field cipher string Cipher name (e.g., "ChaChaPoly", "AESGCM") +--- @field hash string Hash function name (e.g., "SHA256", "BLAKE2s") +--- @field full_name string Original full protocol name + --- Parse pattern and modifiers from the pattern portion --- @param pattern_str string Pattern with modifiers (e.g. "NNpsk0+psk2") --- @return string pattern Base pattern (e.g. "NN") @@ -73,7 +83,7 @@ end --- Parse a Noise protocol name into its components --- @param protocol_name string Full protocol name (e.g. "Noise_NNpsk0+psk2_25519_AESGCM_SHA256") ---- @return table parsed Components: pattern, modifiers, dh, cipher, hash +--- @return ParsedProtocolName parsed Parsed protocol components local function parse_protocol_name(protocol_name) -- Protocol name format: Noise_PATTERNmodifiers_DH_CIPHER_HASH local prefix, pattern_with_modifiers, dh, cipher, hash = @@ -209,9 +219,8 @@ local function make_chachapoly_nonce(n) -- ChaCha20Poly1305 uses little-endian format: 4 zero bytes + 64-bit counter assert(n <= MAX_NONCE, "Nonce overflow") local nonce = string.rep("\0", 4) -- 4 zero bytes padding - - -- Little-endian 64-bit counter - for _ = 0, 7 do + -- Little-endian 64-bit counter (8 bytes) + for _ = 1, 8 do nonce = nonce .. string.char(n % 256) n = math.floor(n / 256) end @@ -624,6 +633,7 @@ end --- @param dh_output string Diffie-Hellman shared secret function SymmetricState:mix_key_and_hash(dh_output) local temp_h, temp_k + --- @type string, string, string self.ck, temp_h, temp_k = self.cipher_suite.hash.hkdf(self.ck, dh_output, 3) self:mix_hash(temp_h) -- Truncate temp_k if needed diff --git a/src/noiseprotocol/openssl_wrapper.lua b/src/noiseprotocol/openssl_wrapper.lua index b91d3d6..904b988 100644 --- a/src/noiseprotocol/openssl_wrapper.lua +++ b/src/noiseprotocol/openssl_wrapper.lua @@ -17,6 +17,7 @@ --- --- Note: X25519 and X448 currently use native implementations only as they are --- not currently supported by lua-openssl. +--- @class noiseprotocol.openssl_wrapper local openssl_wrapper = {} --- OpenSSL Feature Enum diff --git a/src/noiseprotocol/utils/benchmark.lua b/src/noiseprotocol/utils/benchmark.lua index cabb7fa..2388f60 100644 --- a/src/noiseprotocol/utils/benchmark.lua +++ b/src/noiseprotocol/utils/benchmark.lua @@ -1,5 +1,6 @@ --- @module "noiseprotocol.utils.benchmark" --- Common benchmarking utilities for performance testing +--- @class noiseprotocol.utils.benchmark local benchmark = {} --- Run a benchmarked operation with warmup and timing diff --git a/src/noiseprotocol/utils/bytes.lua b/src/noiseprotocol/utils/bytes.lua index a8cd9c6..abfe2aa 100644 --- a/src/noiseprotocol/utils/bytes.lua +++ b/src/noiseprotocol/utils/bytes.lua @@ -1,8 +1,11 @@ --- @module "noiseprotocol.utils.bytes" --- Byte manipulation and conversion utilities +--- @class noiseprotocol.utils.bytes local bytes = {} -local bit32 = require("vendor.bitn").bit32 +local bitn = require("bitn") +local bit32 = bitn.bit32 +local bit64 = bitn.bit64 --- Convert binary string to hexadecimal string --- @param str string Binary string @@ -39,7 +42,7 @@ function bytes.u32_to_be_bytes(n) end --- Convert 64-bit value to 8 bytes (big-endian) ---- @param x Int64HighLow|table {high, low} 64-bit value +--- @param x Int64HighLow {high, low} 64-bit value --- @return string bytes 8-byte string in big-endian order function bytes.u64_to_be_bytes(x) local high, low = x[1], x[2] @@ -47,7 +50,7 @@ function bytes.u64_to_be_bytes(x) end --- Convert 64-bit value to 8 bytes (little-endian) ---- @param x Int64HighLow|table|integer {high, low} 64-bit value or simple integer +--- @param x Int64HighLow|integer {high, low} 64-bit value or simple integer --- @return string bytes 8-byte string in little-endian order function bytes.u64_to_le_bytes(x) -- Handle simple integer case (< 2^53) @@ -313,7 +316,7 @@ function bytes.selftest() { name = "u64 LE - basic table", test = function() - local n = { 0x12345678, 0x9ABCDEF0 } + local n = bit64.new(0x12345678, 0x9ABCDEF0) local bytes_str = bytes.u64_to_le_bytes(n) local back = bytes.le_bytes_to_u64(bytes_str) local b1, b2, b3, b4, b5, b6, b7, b8 = string.byte(bytes_str, 1, 8) @@ -344,7 +347,7 @@ function bytes.selftest() { name = "u64 LE - zero", test = function() - local n = { 0, 0 } + local n = bit64.new(0, 0) local bytes_str = bytes.u64_to_le_bytes(n) local back = bytes.le_bytes_to_u64(bytes_str) return back[1] == 0 and back[2] == 0 and bytes_str == string.rep(string.char(0), 8) @@ -353,7 +356,7 @@ function bytes.selftest() { name = "u64 LE - max value", test = function() - local n = { 0xFFFFFFFF, 0xFFFFFFFF } + local n = bit64.new(0xFFFFFFFF, 0xFFFFFFFF) local bytes_str = bytes.u64_to_le_bytes(n) local back = bytes.le_bytes_to_u64(bytes_str) return back[1] == 0xFFFFFFFF and back[2] == 0xFFFFFFFF and bytes_str == string.rep(string.char(0xFF), 8) @@ -362,7 +365,7 @@ function bytes.selftest() { name = "u64 LE - high word only", test = function() - local n = { 0x12345678, 0 } + local n = bit64.new(0x12345678, 0) local bytes_str = bytes.u64_to_le_bytes(n) local back = bytes.le_bytes_to_u64(bytes_str) return back[1] == 0x12345678 and back[2] == 0 @@ -371,7 +374,7 @@ function bytes.selftest() { name = "u64 LE - low word only", test = function() - local n = { 0, 0x12345678 } + local n = bit64.new(0, 0x12345678) local bytes_str = bytes.u64_to_le_bytes(n) local back = bytes.le_bytes_to_u64(bytes_str) return back[1] == 0 and back[2] == 0x12345678 @@ -380,7 +383,7 @@ function bytes.selftest() { name = "u64 LE - with offset", test = function() - local data = "XXX" .. bytes.u64_to_le_bytes({ 0x12345678, 0x9ABCDEF0 }) .. "YYY" + local data = "XXX" .. bytes.u64_to_le_bytes(bit64.new(0x12345678, 0x9ABCDEF0)) .. "YYY" local n = bytes.le_bytes_to_u64(data, 4) return n[1] == 0x12345678 and n[2] == 0x9ABCDEF0 end, @@ -388,7 +391,7 @@ function bytes.selftest() { name = "u64 BE - basic", test = function() - local n = { 0x12345678, 0x9ABCDEF0 } + local n = bit64.new(0x12345678, 0x9ABCDEF0) local bytes_str = bytes.u64_to_be_bytes(n) local back = bytes.be_bytes_to_u64(bytes_str) local b1, b2, b3, b4, b5, b6, b7, b8 = string.byte(bytes_str, 1, 8) @@ -407,7 +410,7 @@ function bytes.selftest() { name = "u64 BE - zero", test = function() - local n = { 0, 0 } + local n = bit64.new(0, 0) local bytes_str = bytes.u64_to_be_bytes(n) local back = bytes.be_bytes_to_u64(bytes_str) return back[1] == 0 and back[2] == 0 and bytes_str == string.rep(string.char(0), 8) @@ -416,7 +419,7 @@ function bytes.selftest() { name = "u64 BE - with offset", test = function() - local data = "XXX" .. bytes.u64_to_be_bytes({ 0x12345678, 0x9ABCDEF0 }) .. "YYY" + local data = "XXX" .. bytes.u64_to_be_bytes(bit64.new(0x12345678, 0x9ABCDEF0)) .. "YYY" local n = bytes.be_bytes_to_u64(data, 4) return n[1] == 0x12345678 and n[2] == 0x9ABCDEF0 end, @@ -602,35 +605,35 @@ function bytes.selftest() name = "u32 LE - insufficient bytes", test = function() local ok, err = pcall(bytes.le_bytes_to_u32, "XX") - return not ok and err:match("Insufficient bytes") + return not ok and type(err) == "string" and err:match("Insufficient bytes") end, }, { name = "u32 BE - insufficient bytes", test = function() local ok, err = pcall(bytes.be_bytes_to_u32, "XX") - return not ok and err:match("Insufficient bytes") + return not ok and type(err) == "string" and err:match("Insufficient bytes") end, }, { name = "u64 LE - insufficient bytes", test = function() local ok, err = pcall(bytes.le_bytes_to_u64, "XXXXXX") - return not ok and err:match("Insufficient bytes") + return not ok and type(err) == "string" and err:match("Insufficient bytes") end, }, { name = "u64 BE - insufficient bytes", test = function() local ok, err = pcall(bytes.be_bytes_to_u64, "XXXXXX") - return not ok and err:match("Insufficient bytes") + return not ok and type(err) == "string" and err:match("Insufficient bytes") end, }, { name = "xor - length mismatch", test = function() local ok, err = pcall(bytes.xor_bytes, "abc", "abcd") - return not ok and err:match("same length") + return not ok and type(err) == "string" and err:match("same length") end, }, } diff --git a/src/noiseprotocol/utils/init.lua b/src/noiseprotocol/utils/init.lua index 43ab127..cbb396c 100644 --- a/src/noiseprotocol/utils/init.lua +++ b/src/noiseprotocol/utils/init.lua @@ -1,7 +1,10 @@ --- @module "noiseprotocol.utils" --- Common utility functions for the Noise Protocol Framework +--- @class noiseprotocol.utils local utils = { + --- @type noiseprotocol.utils.bytes bytes = require("noiseprotocol.utils.bytes"), + --- @type noiseprotocol.utils.benchmark benchmark = require("noiseprotocol.utils.benchmark"), } diff --git a/vendor/bitn.lua b/vendor/bitn.lua index f79e11d..f83484d 100644 --- a/vendor/bitn.lua +++ b/vendor/bitn.lua @@ -1,110 +1,396 @@ do local _ENV = _ENV -package.preload[ "bitn.bit16" ] = function( ... ) local arg = _G.arg; ---- @module "bitn.bit16" ---- Pure Lua 16-bit bitwise operations library. ---- This module provides a complete, version-agnostic implementation of 16-bit ---- bitwise operations that works across Lua 5.1, 5.2, 5.3, 5.4, and LuaJIT ---- without depending on any built-in bit libraries. ---- @class bit16 -local bit16 = {} +package.preload[ "bitn._compat" ] = function( ... ) local arg = _G.arg; +--- @diagnostic disable: duplicate-set-field +--- @module "bitn._compat" +--- Internal compatibility layer for bitwise operations. +--- Provides feature detection and optimized primitives for use by bit16/bit32/bit64. +--- @class bitn._compat +local _compat = {} --- 16-bit mask constant -local MASK16 = 0xFFFF +-------------------------------------------------------------------------------- +-- Helper functions (needed by all implementations) +-------------------------------------------------------------------------------- ---- Ensure value fits in 16-bit unsigned integer. ---- @param n number Input value ---- @return integer result 16-bit unsigned integer (0 to 0xFFFF) -function bit16.mask(n) - return math.floor(n % 0x10000) +local math_floor = math.floor +local math_pow = math.pow or function(x, y) + return x ^ y end ---- Bitwise AND operation. ---- @param a integer First operand (16-bit) ---- @param b integer Second operand (16-bit) ---- @return integer result Result of a AND b -function bit16.band(a, b) - a = bit16.mask(a) - b = bit16.mask(b) +--- Convert signed 32-bit to unsigned (for LuaJIT which returns signed values) +--- @param n number Potentially signed 32-bit value +--- @return number Unsigned 32-bit value +local function to_unsigned(n) + if n < 0 then + return n + 0x100000000 + end + return n +end - local result = 0 - local bit_val = 1 +_compat.to_unsigned = to_unsigned - for _ = 0, 15 do - if (a % 2 == 1) and (b % 2 == 1) then - result = result + bit_val +-- Constants +local MASK32 = 0xFFFFFFFF + +-------------------------------------------------------------------------------- +-- Implementation 1: Native operators (Lua 5.3+) +-------------------------------------------------------------------------------- + +local ok, result = pcall(load, "return function(a,b) return a & b end") +if ok and result then + local fn = result() + if fn then + -- Native operators available - define all functions using them + local native_band = fn + local native_bor = assert(load("return function(a,b) return a | b end"))() + local native_bxor = assert(load("return function(a,b) return a ~ b end"))() + local native_bnot = assert(load("return function(a) return ~a end"))() + local native_lshift = assert(load("return function(a,n) return a << n end"))() + local native_rshift = assert(load("return function(a,n) return a >> n end"))() + + _compat.has_native_ops = true + _compat.has_bit_lib = false + _compat.is_luajit = false + + function _compat.impl_name() + return "native operators (Lua 5.3+)" end - a = math.floor(a / 2) - b = math.floor(b / 2) - bit_val = bit_val * 2 - if a == 0 and b == 0 then - break + function _compat.band(a, b) + return native_band(a, b) + end + + function _compat.bor(a, b) + return native_bor(a, b) + end + + function _compat.bxor(a, b) + return native_bxor(a, b) + end + + function _compat.bnot(a) + return native_band(native_bnot(a), MASK32) + end + + function _compat.lshift(a, n) + if n >= 32 then + return 0 + end + return native_band(native_lshift(a, n), MASK32) + end + + function _compat.rshift(a, n) + if n >= 32 then + return 0 + end + return native_rshift(native_band(a, MASK32), n) + end + + function _compat.arshift(a, n) + a = native_band(a, MASK32) + local is_negative = a >= 0x80000000 + if n >= 32 then + return is_negative and MASK32 or 0 + end + local r = native_rshift(a, n) + if is_negative then + local fill_mask = native_lshift(MASK32, 32 - n) + r = native_bor(r, native_band(fill_mask, MASK32)) + end + return native_band(r, MASK32) end + + return _compat end +end + +-------------------------------------------------------------------------------- +-- Implementation 2: Bit library (LuaJIT or Lua 5.2) +-------------------------------------------------------------------------------- - return result +local bit_lib +local is_luajit = false + +-- Try LuaJIT's bit library first +ok, result = pcall(require, "bit") +if ok and result then + bit_lib = result + is_luajit = true +else + -- Try Lua 5.2's bit32 library (use rawget to avoid recursion with our module name) + bit_lib = rawget(_G, "bit32") end ---- Bitwise OR operation. ---- @param a integer First operand (16-bit) ---- @param b integer Second operand (16-bit) ---- @return integer result Result of a OR b -function bit16.bor(a, b) - a = bit16.mask(a) - b = bit16.mask(b) +if bit_lib then + -- Bit library available - define all functions using it + local bit_band = assert(bit_lib.band) + local bit_bor = assert(bit_lib.bor) + local bit_bxor = assert(bit_lib.bxor) + local bit_bnot = assert(bit_lib.bnot) + local bit_lshift = assert(bit_lib.lshift) + local bit_rshift = assert(bit_lib.rshift) + local bit_arshift = assert(bit_lib.arshift) + + _compat.has_native_ops = false + _compat.has_bit_lib = true + _compat.is_luajit = is_luajit + + function _compat.impl_name() + return "bit library" + end + + if is_luajit then + -- LuaJIT returns signed integers, need to convert to unsigned + function _compat.band(a, b) + return to_unsigned(bit_band(a, b)) + end + + function _compat.bor(a, b) + return to_unsigned(bit_bor(a, b)) + end + + function _compat.bxor(a, b) + return to_unsigned(bit_bxor(a, b)) + end + + function _compat.bnot(a) + return to_unsigned(bit_bnot(a)) + end + + function _compat.lshift(a, n) + if n >= 32 then + return 0 + end + return to_unsigned(bit_lshift(a, n)) + end - local result = 0 + function _compat.rshift(a, n) + if n >= 32 then + return 0 + end + return to_unsigned(bit_rshift(a, n)) + end + + function _compat.arshift(a, n) + a = to_unsigned(bit_band(a, MASK32)) + if n >= 32 then + local is_negative = a >= 0x80000000 + return is_negative and MASK32 or 0 + end + return to_unsigned(bit_arshift(a, n)) + end + else + -- Lua 5.2 bit32 library returns unsigned integers + function _compat.band(a, b) + return bit_band(a, b) + end + + function _compat.bor(a, b) + return bit_bor(a, b) + end + + function _compat.bxor(a, b) + return bit_bxor(a, b) + end + + function _compat.bnot(a) + return bit_band(bit_bnot(a), MASK32) + end + + function _compat.lshift(a, n) + if n >= 32 then + return 0 + end + return bit_band(bit_lshift(a, n), MASK32) + end + + function _compat.rshift(a, n) + if n >= 32 then + return 0 + end + return bit_rshift(bit_band(a, MASK32), n) + end + + function _compat.arshift(a, n) + a = bit_band(a, MASK32) + if n >= 32 then + local is_negative = a >= 0x80000000 + return is_negative and MASK32 or 0 + end + return bit_band(bit_arshift(a, n), MASK32) + end + end + + return _compat +end + +-------------------------------------------------------------------------------- +-- Implementation 3: Pure Lua fallback +-------------------------------------------------------------------------------- + +_compat.has_native_ops = false +_compat.has_bit_lib = false +_compat.is_luajit = false + +function _compat.impl_name() + return "pure Lua" +end + +function _compat.band(a, b) + local r = 0 local bit_val = 1 + for _ = 0, 31 do + if (a % 2 == 1) and (b % 2 == 1) then + r = r + bit_val + end + a = math_floor(a / 2) + b = math_floor(b / 2) + bit_val = bit_val * 2 + if a == 0 and b == 0 then + break + end + end + return r +end - for _ = 0, 15 do +function _compat.bor(a, b) + local r = 0 + local bit_val = 1 + for _ = 0, 31 do if (a % 2 == 1) or (b % 2 == 1) then - result = result + bit_val + r = r + bit_val end - a = math.floor(a / 2) - b = math.floor(b / 2) + a = math_floor(a / 2) + b = math_floor(b / 2) bit_val = bit_val * 2 - if a == 0 and b == 0 then break end end - - return result + return r end ---- Bitwise XOR operation. ---- @param a integer First operand (16-bit) ---- @param b integer Second operand (16-bit) ---- @return integer result Result of a XOR b -function bit16.bxor(a, b) - a = bit16.mask(a) - b = bit16.mask(b) - - local result = 0 +function _compat.bxor(a, b) + local r = 0 local bit_val = 1 - - for _ = 0, 15 do + for _ = 0, 31 do if (a % 2) ~= (b % 2) then - result = result + bit_val + r = r + bit_val end - a = math.floor(a / 2) - b = math.floor(b / 2) + a = math_floor(a / 2) + b = math_floor(b / 2) bit_val = bit_val * 2 - if a == 0 and b == 0 then break end end + return r +end + +function _compat.bnot(a) + return MASK32 - (math_floor(a) % 0x100000000) +end + +function _compat.lshift(a, n) + if n >= 32 then + return 0 + end + return math_floor((a * math_pow(2, n)) % 0x100000000) +end + +function _compat.rshift(a, n) + if n >= 32 then + return 0 + end + a = math_floor(a) % 0x100000000 + return math_floor(a / math_pow(2, n)) +end + +function _compat.arshift(a, n) + a = math_floor(a) % 0x100000000 + local is_negative = a >= 0x80000000 + if n >= 32 then + return is_negative and MASK32 or 0 + end + local r = math_floor(a / math_pow(2, n)) + if is_negative then + local fill_mask = MASK32 - (math_pow(2, 32 - n) - 1) + r = _compat.bor(r, fill_mask) + end + return r +end + +return _compat +end +end + +do +local _ENV = _ENV +package.preload[ "bitn.bit16" ] = function( ... ) local arg = _G.arg; +--- @module "bitn.bit16" +--- 16-bit bitwise operations library. +--- This module provides a complete, version-agnostic implementation of 16-bit +--- bitwise operations that works across Lua 5.1, 5.2, 5.3, 5.4, and LuaJIT. +--- Uses native bit operations where available for optimal performance. +--- @class bitn.bit16 +local bit16 = {} + +local _compat = require("bitn._compat") + +-- Cache methods as locals for faster access +local compat_band = _compat.band +local compat_bor = _compat.bor +local compat_bxor = _compat.bxor +local compat_bnot = _compat.bnot +local compat_lshift = _compat.lshift +local compat_rshift = _compat.rshift +local impl_name = _compat.impl_name + +-- 16-bit mask constant +local MASK16 = 0xFFFF + +local math_floor = math.floor + +-------------------------------------------------------------------------------- +-- Core operations +-------------------------------------------------------------------------------- + +--- Ensure value fits in 16-bit unsigned integer. +--- @param n number Input value +--- @return integer result 16-bit unsigned integer (0 to 0xFFFF) +function bit16.mask(n) + return compat_band(math_floor(n), MASK16) +end + +--- Bitwise AND operation. +--- @param a integer First operand (16-bit) +--- @param b integer Second operand (16-bit) +--- @return integer result Result of a AND b +function bit16.band(a, b) + return compat_band(compat_band(a, MASK16), compat_band(b, MASK16)) +end + +--- Bitwise OR operation. +--- @param a integer First operand (16-bit) +--- @param b integer Second operand (16-bit) +--- @return integer result Result of a OR b +function bit16.bor(a, b) + return compat_band(compat_bor(a, b), MASK16) +end - return result +--- Bitwise XOR operation. +--- @param a integer First operand (16-bit) +--- @param b integer Second operand (16-bit) +--- @return integer result Result of a XOR b +function bit16.bxor(a, b) + return compat_band(compat_bxor(a, b), MASK16) end --- Bitwise NOT operation. --- @param a integer Operand (16-bit) --- @return integer result Result of NOT a function bit16.bnot(a) - return bit16.mask(MASK16 - bit16.mask(a)) + return compat_band(compat_bnot(a), MASK16) end --- Left shift operation. @@ -116,7 +402,7 @@ function bit16.lshift(a, n) if n >= 16 then return 0 end - return bit16.mask(bit16.mask(a) * math.pow(2, n)) + return compat_band(compat_lshift(compat_band(a, MASK16), n), MASK16) end --- Logical right shift operation (fills with 0s). @@ -125,11 +411,10 @@ end --- @return integer result Result of a >> n (logical) function bit16.rshift(a, n) assert(n >= 0, "Shift amount must be non-negative") - a = bit16.mask(a) if n >= 16 then return 0 end - return math.floor(a / math.pow(2, n)) + return compat_rshift(compat_band(a, MASK16), n) end --- Arithmetic right shift operation (sign-extending, fills with sign bit). @@ -138,27 +423,25 @@ end --- @return integer result Result of a >> n with sign extension function bit16.arshift(a, n) assert(n >= 0, "Shift amount must be non-negative") - a = bit16.mask(a) + a = compat_band(a, MASK16) -- Check if sign bit is set (bit 15) local is_negative = a >= 0x8000 if n >= 16 then - -- All bits shift out, result is all 1s if negative, all 0s if positive - return is_negative and 0xFFFF or 0 + return is_negative and MASK16 or 0 end - -- Perform logical right shift first - local result = math.floor(a / math.pow(2, n)) + -- Perform logical right shift + local result = compat_rshift(a, n) -- If original was negative, fill high bits with 1s if is_negative then - -- Create mask for high bits that need to be 1 - local fill_mask = MASK16 - (math.floor(2 ^ (16 - n)) - 1) - result = bit16.bor(result, fill_mask) + local fill_mask = compat_band(compat_lshift(MASK16, 16 - n), MASK16) + result = compat_bor(result, fill_mask) end - return result + return compat_band(result, MASK16) end --- Left rotate operation. @@ -167,8 +450,8 @@ end --- @return integer result Result of rotating x left by n positions function bit16.rol(x, n) n = n % 16 - x = bit16.mask(x) - return bit16.mask(bit16.lshift(x, n) + bit16.rshift(x, 16 - n)) + x = compat_band(x, MASK16) + return compat_band(compat_bor(compat_lshift(x, n), compat_rshift(x, 16 - n)), MASK16) end --- Right rotate operation. @@ -177,8 +460,8 @@ end --- @return integer result Result of rotating x right by n positions function bit16.ror(x, n) n = n % 16 - x = bit16.mask(x) - return bit16.mask(bit16.rshift(x, n) + bit16.lshift(x, 16 - n)) + x = compat_band(x, MASK16) + return compat_band(compat_bor(compat_rshift(x, n), compat_lshift(x, 16 - n)), MASK16) end --- 16-bit addition with overflow handling. @@ -186,27 +469,30 @@ end --- @param b integer Second operand (16-bit) --- @return integer result Result of (a + b) mod 2^16 function bit16.add(a, b) - return bit16.mask(bit16.mask(a) + bit16.mask(b)) + return compat_band(compat_band(a, MASK16) + compat_band(b, MASK16), MASK16) end -------------------------------------------------------------------------------- -- Byte conversion functions -------------------------------------------------------------------------------- +local string_char = string.char +local string_byte = string.byte + --- Convert 16-bit unsigned integer to 2 bytes (big-endian). --- @param n integer 16-bit unsigned integer --- @return string bytes 2-byte string in big-endian order function bit16.u16_to_be_bytes(n) - n = bit16.mask(n) - return string.char(math.floor(n / 256), n % 256) + n = compat_band(n, MASK16) + return string_char(math_floor(n / 256), n % 256) end --- Convert 16-bit unsigned integer to 2 bytes (little-endian). --- @param n integer 16-bit unsigned integer --- @return string bytes 2-byte string in little-endian order function bit16.u16_to_le_bytes(n) - n = bit16.mask(n) - return string.char(n % 256, math.floor(n / 256)) + n = compat_band(n, MASK16) + return string_char(n % 256, math_floor(n / 256)) end --- Convert 2 bytes to 16-bit unsigned integer (big-endian). @@ -216,7 +502,7 @@ end function bit16.be_bytes_to_u16(str, offset) offset = offset or 1 assert(#str >= offset + 1, "Insufficient bytes for u16") - local b1, b2 = string.byte(str, offset, offset + 1) + local b1, b2 = string_byte(str, offset, offset + 1) return b1 * 256 + b2 end @@ -227,7 +513,7 @@ end function bit16.le_bytes_to_u16(str, offset) offset = offset or 1 assert(#str >= offset + 1, "Insufficient bytes for u16") - local b1, b2 = string.byte(str, offset, offset + 1) + local b1, b2 = string_byte(str, offset, offset + 1) return b1 + b2 * 256 end @@ -242,6 +528,7 @@ local unpack_fn = unpack or table.unpack --- @return boolean result True if all tests pass, false otherwise function bit16.selftest() print("Running 16-bit operations test vectors...") + print(string.format(" Using: %s", impl_name())) local passed = 0 local total = 0 @@ -445,6 +732,94 @@ function bit16.selftest() return passed == total end +-------------------------------------------------------------------------------- +-- Benchmarking +-------------------------------------------------------------------------------- + +local benchmark_op = require("bitn.utils.benchmark").benchmark_op + +--- Run performance benchmarks for 16-bit operations. +function bit16.benchmark() + local iterations = 100000 + + print("16-bit Bitwise Operations:") + print(string.format(" Implementation: %s", impl_name())) + + -- Test values + local a, b = 0xAAAA, 0x5555 + + benchmark_op("band", function() + bit16.band(a, b) + end, iterations) + + benchmark_op("bor", function() + bit16.bor(a, b) + end, iterations) + + benchmark_op("bxor", function() + bit16.bxor(a, b) + end, iterations) + + benchmark_op("bnot", function() + bit16.bnot(a) + end, iterations) + + print("\n16-bit Shift Operations:") + + benchmark_op("lshift", function() + bit16.lshift(a, 4) + end, iterations) + + benchmark_op("rshift", function() + bit16.rshift(a, 4) + end, iterations) + + benchmark_op("arshift", function() + bit16.arshift(0x8000, 4) + end, iterations) + + print("\n16-bit Rotate Operations:") + + benchmark_op("rol", function() + bit16.rol(a, 4) + end, iterations) + + benchmark_op("ror", function() + bit16.ror(a, 4) + end, iterations) + + print("\n16-bit Arithmetic:") + + benchmark_op("add", function() + bit16.add(a, b) + end, iterations) + + benchmark_op("mask", function() + bit16.mask(0x12345) + end, iterations) + + print("\n16-bit Byte Conversions:") + + local bytes_be = bit16.u16_to_be_bytes(0x1234) + local bytes_le = bit16.u16_to_le_bytes(0x1234) + + benchmark_op("u16_to_be_bytes", function() + bit16.u16_to_be_bytes(0x1234) + end, iterations) + + benchmark_op("u16_to_le_bytes", function() + bit16.u16_to_le_bytes(0x1234) + end, iterations) + + benchmark_op("be_bytes_to_u16", function() + bit16.be_bytes_to_u16(bytes_be) + end, iterations) + + benchmark_op("le_bytes_to_u16", function() + bit16.le_bytes_to_u16(bytes_le) + end, iterations) +end + return bit16 end end @@ -453,21 +828,39 @@ do local _ENV = _ENV package.preload[ "bitn.bit32" ] = function( ... ) local arg = _G.arg; --- @module "bitn.bit32" ---- Pure Lua 32-bit bitwise operations library. +--- 32-bit bitwise operations library. --- This module provides a complete, version-agnostic implementation of 32-bit ---- bitwise operations that works across Lua 5.1, 5.2, 5.3, 5.4, and LuaJIT ---- without depending on any built-in bit libraries. ---- @class bit32 +--- bitwise operations that works across Lua 5.1, 5.2, 5.3, 5.4, and LuaJIT. +--- Uses native bit operations where available for optimal performance. +--- @class bitn.bit32 local bit32 = {} +local _compat = require("bitn._compat") + +-- Cache methods as locals for faster access +local compat_band = _compat.band +local compat_bor = _compat.bor +local compat_bxor = _compat.bxor +local compat_bnot = _compat.bnot +local compat_lshift = _compat.lshift +local compat_rshift = _compat.rshift +local compat_arshift = _compat.arshift +local impl_name = _compat.impl_name + -- 32-bit mask constant local MASK32 = 0xFFFFFFFF +local math_floor = math.floor + +-------------------------------------------------------------------------------- +-- Core operations +-------------------------------------------------------------------------------- + --- Ensure value fits in 32-bit unsigned integer. --- @param n number Input value --- @return integer result 32-bit unsigned integer (0 to 0xFFFFFFFF) function bit32.mask(n) - return math.floor(n % 0x100000000) + return compat_band(math_floor(n), MASK32) end --- Bitwise AND operation. @@ -475,26 +868,7 @@ end --- @param b integer Second operand (32-bit) --- @return integer result Result of a AND b function bit32.band(a, b) - a = bit32.mask(a) - b = bit32.mask(b) - - local result = 0 - local bit_val = 1 - - for _ = 0, 31 do - if (a % 2 == 1) and (b % 2 == 1) then - result = result + bit_val - end - a = math.floor(a / 2) - b = math.floor(b / 2) - bit_val = bit_val * 2 - - if a == 0 and b == 0 then - break - end - end - - return result + return compat_band(compat_band(a, MASK32), compat_band(b, MASK32)) end --- Bitwise OR operation. @@ -502,26 +876,7 @@ end --- @param b integer Second operand (32-bit) --- @return integer result Result of a OR b function bit32.bor(a, b) - a = bit32.mask(a) - b = bit32.mask(b) - - local result = 0 - local bit_val = 1 - - for _ = 0, 31 do - if (a % 2 == 1) or (b % 2 == 1) then - result = result + bit_val - end - a = math.floor(a / 2) - b = math.floor(b / 2) - bit_val = bit_val * 2 - - if a == 0 and b == 0 then - break - end - end - - return result + return compat_band(compat_bor(a, b), MASK32) end --- Bitwise XOR operation. @@ -529,33 +884,14 @@ end --- @param b integer Second operand (32-bit) --- @return integer result Result of a XOR b function bit32.bxor(a, b) - a = bit32.mask(a) - b = bit32.mask(b) - - local result = 0 - local bit_val = 1 - - for _ = 0, 31 do - if (a % 2) ~= (b % 2) then - result = result + bit_val - end - a = math.floor(a / 2) - b = math.floor(b / 2) - bit_val = bit_val * 2 - - if a == 0 and b == 0 then - break - end - end - - return result + return compat_band(compat_bxor(a, b), MASK32) end --- Bitwise NOT operation. --- @param a integer Operand (32-bit) --- @return integer result Result of NOT a function bit32.bnot(a) - return bit32.mask(MASK32 - bit32.mask(a)) + return compat_band(compat_bnot(a), MASK32) end --- Left shift operation. @@ -567,7 +903,7 @@ function bit32.lshift(a, n) if n >= 32 then return 0 end - return bit32.mask(bit32.mask(a) * math.pow(2, n)) + return compat_band(compat_lshift(compat_band(a, MASK32), n), MASK32) end --- Logical right shift operation (fills with 0s). @@ -576,11 +912,10 @@ end --- @return integer result Result of a >> n (logical) function bit32.rshift(a, n) assert(n >= 0, "Shift amount must be non-negative") - a = bit32.mask(a) if n >= 32 then return 0 end - return math.floor(a / math.pow(2, n)) + return compat_rshift(compat_band(a, MASK32), n) end --- Arithmetic right shift operation (sign-extending, fills with sign bit). @@ -589,27 +924,7 @@ end --- @return integer result Result of a >> n with sign extension function bit32.arshift(a, n) assert(n >= 0, "Shift amount must be non-negative") - a = bit32.mask(a) - - -- Check if sign bit is set (bit 31) - local is_negative = a >= 0x80000000 - - if n >= 32 then - -- All bits shift out, result is all 1s if negative, all 0s if positive - return is_negative and 0xFFFFFFFF or 0 - end - - -- Perform logical right shift first - local result = math.floor(a / math.pow(2, n)) - - -- If original was negative, fill high bits with 1s - if is_negative then - -- Create mask for high bits that need to be 1 - local fill_mask = MASK32 - (math.pow(2, 32 - n) - 1) - result = bit32.bor(result, fill_mask) - end - - return result + return compat_arshift(a, n) end --- Left rotate operation. @@ -618,8 +933,8 @@ end --- @return integer result Result of rotating x left by n positions function bit32.rol(x, n) n = n % 32 - x = bit32.mask(x) - return bit32.mask(bit32.lshift(x, n) + bit32.rshift(x, 32 - n)) + x = compat_band(x, MASK32) + return compat_band(compat_bor(compat_lshift(x, n), compat_rshift(x, 32 - n)), MASK32) end --- Right rotate operation. @@ -628,8 +943,8 @@ end --- @return integer result Result of rotating x right by n positions function bit32.ror(x, n) n = n % 32 - x = bit32.mask(x) - return bit32.mask(bit32.rshift(x, n) + bit32.lshift(x, 32 - n)) + x = compat_band(x, MASK32) + return compat_band(compat_bor(compat_rshift(x, n), compat_lshift(x, 32 - n)), MASK32) end --- 32-bit addition with overflow handling. @@ -637,27 +952,30 @@ end --- @param b integer Second operand (32-bit) --- @return integer result Result of (a + b) mod 2^32 function bit32.add(a, b) - return bit32.mask(bit32.mask(a) + bit32.mask(b)) + return compat_band(compat_band(a, MASK32) + compat_band(b, MASK32), MASK32) end -------------------------------------------------------------------------------- -- Byte conversion functions -------------------------------------------------------------------------------- +local string_char = string.char +local string_byte = string.byte + --- Convert 32-bit unsigned integer to 4 bytes (big-endian). --- @param n integer 32-bit unsigned integer --- @return string bytes 4-byte string in big-endian order function bit32.u32_to_be_bytes(n) - n = bit32.mask(n) - return string.char(math.floor(n / 16777216) % 256, math.floor(n / 65536) % 256, math.floor(n / 256) % 256, n % 256) + n = compat_band(n, MASK32) + return string_char(math_floor(n / 16777216) % 256, math_floor(n / 65536) % 256, math_floor(n / 256) % 256, n % 256) end --- Convert 32-bit unsigned integer to 4 bytes (little-endian). --- @param n integer 32-bit unsigned integer --- @return string bytes 4-byte string in little-endian order function bit32.u32_to_le_bytes(n) - n = bit32.mask(n) - return string.char(n % 256, math.floor(n / 256) % 256, math.floor(n / 65536) % 256, math.floor(n / 16777216) % 256) + n = compat_band(n, MASK32) + return string_char(n % 256, math_floor(n / 256) % 256, math_floor(n / 65536) % 256, math_floor(n / 16777216) % 256) end --- Convert 4 bytes to 32-bit unsigned integer (big-endian). @@ -667,7 +985,7 @@ end function bit32.be_bytes_to_u32(str, offset) offset = offset or 1 assert(#str >= offset + 3, "Insufficient bytes for u32") - local b1, b2, b3, b4 = string.byte(str, offset, offset + 3) + local b1, b2, b3, b4 = string_byte(str, offset, offset + 3) return b1 * 16777216 + b2 * 65536 + b3 * 256 + b4 end @@ -678,7 +996,7 @@ end function bit32.le_bytes_to_u32(str, offset) offset = offset or 1 assert(#str >= offset + 3, "Insufficient bytes for u32") - local b1, b2, b3, b4 = string.byte(str, offset, offset + 3) + local b1, b2, b3, b4 = string_byte(str, offset, offset + 3) return b1 + b2 * 256 + b3 * 65536 + b4 * 16777216 end @@ -693,6 +1011,7 @@ local unpack_fn = unpack or table.unpack --- @return boolean result True if all tests pass, false otherwise function bit32.selftest() print("Running 32-bit operations test vectors...") + print(string.format(" Using: %s", impl_name())) local passed = 0 local total = 0 @@ -948,6 +1267,94 @@ function bit32.selftest() return passed == total end +-------------------------------------------------------------------------------- +-- Benchmarking +-------------------------------------------------------------------------------- + +local benchmark_op = require("bitn.utils.benchmark").benchmark_op + +--- Run performance benchmarks for 32-bit operations. +function bit32.benchmark() + local iterations = 100000 + + print("32-bit Bitwise Operations:") + print(string.format(" Implementation: %s", impl_name())) + + -- Test values + local a, b = 0xAAAAAAAA, 0x55555555 + + benchmark_op("band", function() + bit32.band(a, b) + end, iterations) + + benchmark_op("bor", function() + bit32.bor(a, b) + end, iterations) + + benchmark_op("bxor", function() + bit32.bxor(a, b) + end, iterations) + + benchmark_op("bnot", function() + bit32.bnot(a) + end, iterations) + + print("\n32-bit Shift Operations:") + + benchmark_op("lshift", function() + bit32.lshift(a, 8) + end, iterations) + + benchmark_op("rshift", function() + bit32.rshift(a, 8) + end, iterations) + + benchmark_op("arshift", function() + bit32.arshift(0x80000000, 8) + end, iterations) + + print("\n32-bit Rotate Operations:") + + benchmark_op("rol", function() + bit32.rol(a, 8) + end, iterations) + + benchmark_op("ror", function() + bit32.ror(a, 8) + end, iterations) + + print("\n32-bit Arithmetic:") + + benchmark_op("add", function() + bit32.add(a, b) + end, iterations) + + benchmark_op("mask", function() + bit32.mask(0x123456789) + end, iterations) + + print("\n32-bit Byte Conversions:") + + local bytes_be = bit32.u32_to_be_bytes(0x12345678) + local bytes_le = bit32.u32_to_le_bytes(0x12345678) + + benchmark_op("u32_to_be_bytes", function() + bit32.u32_to_be_bytes(0x12345678) + end, iterations) + + benchmark_op("u32_to_le_bytes", function() + bit32.u32_to_le_bytes(0x12345678) + end, iterations) + + benchmark_op("be_bytes_to_u32", function() + bit32.be_bytes_to_u32(bytes_be) + end, iterations) + + benchmark_op("le_bytes_to_u32", function() + bit32.le_bytes_to_u32(bytes_le) + end, iterations) +end + return bit32 end end @@ -956,15 +1363,30 @@ do local _ENV = _ENV package.preload[ "bitn.bit64" ] = function( ... ) local arg = _G.arg; --- @module "bitn.bit64" ---- Pure Lua 64-bit bitwise operations library. +--- 64-bit bitwise operations library. --- This module provides 64-bit bitwise operations using {high, low} pairs, --- where high is the upper 32 bits and low is the lower 32 bits. ---- Works across Lua 5.1, 5.2, 5.3, 5.4, and LuaJIT without depending on ---- any built-in bit libraries. ---- @class bit64 +--- Works across Lua 5.1, 5.2, 5.3, 5.4, and LuaJIT. +--- Uses native bit operations where available for optimal performance. +--- @class bitn.bit64 local bit64 = {} local bit32 = require("bitn.bit32") +local _compat = require("bitn._compat") +local impl_name = _compat.impl_name + +-- Cache bit32 methods as locals for faster access +local bit32_band = bit32.band +local bit32_bor = bit32.bor +local bit32_bxor = bit32.bxor +local bit32_bnot = bit32.bnot +local bit32_lshift = bit32.lshift +local bit32_rshift = bit32.rshift +local bit32_arshift = bit32.arshift +local bit32_u32_to_be_bytes = bit32.u32_to_be_bytes +local bit32_u32_to_le_bytes = bit32.u32_to_le_bytes +local bit32_be_bytes_to_u32 = bit32.be_bytes_to_u32 +local bit32_le_bytes_to_u32 = bit32.le_bytes_to_u32 -- Private metatable for Int64 type identification local Int64Meta = { __name = "Int64" } @@ -1000,7 +1422,7 @@ end --- @param b Int64HighLow Second operand {high, low} --- @return Int64HighLow result {high, low} AND result function bit64.band(a, b) - return bit64.new(bit32.band(a[1], b[1]), bit32.band(a[2], b[2])) + return bit64.new(bit32_band(a[1], b[1]), bit32_band(a[2], b[2])) end --- Bitwise OR operation. @@ -1008,7 +1430,7 @@ end --- @param b Int64HighLow Second operand {high, low} --- @return Int64HighLow result {high, low} OR result function bit64.bor(a, b) - return bit64.new(bit32.bor(a[1], b[1]), bit32.bor(a[2], b[2])) + return bit64.new(bit32_bor(a[1], b[1]), bit32_bor(a[2], b[2])) end --- Bitwise XOR operation. @@ -1016,14 +1438,14 @@ end --- @param b Int64HighLow Second operand {high, low} --- @return Int64HighLow result {high, low} XOR result function bit64.bxor(a, b) - return bit64.new(bit32.bxor(a[1], b[1]), bit32.bxor(a[2], b[2])) + return bit64.new(bit32_bxor(a[1], b[1]), bit32_bxor(a[2], b[2])) end --- Bitwise NOT operation. --- @param a Int64HighLow Operand {high, low} --- @return Int64HighLow result {high, low} NOT result function bit64.bnot(a) - return bit64.new(bit32.bnot(a[1]), bit32.bnot(a[2])) + return bit64.new(bit32_bnot(a[1]), bit32_bnot(a[2])) end -------------------------------------------------------------------------------- @@ -1041,11 +1463,11 @@ function bit64.lshift(x, n) return bit64.new(0, 0) elseif n >= 32 then -- Shift by 32 or more: low becomes 0, high gets bits from low - return bit64.new(bit32.lshift(x[2], n - 32), 0) + return bit64.new(bit32_lshift(x[2], n - 32), 0) else -- Shift by less than 32 - local new_high = bit32.bor(bit32.lshift(x[1], n), bit32.rshift(x[2], 32 - n)) - local new_low = bit32.lshift(x[2], n) + local new_high = bit32_bor(bit32_lshift(x[1], n), bit32_rshift(x[2], 32 - n)) + local new_low = bit32_lshift(x[2], n) return bit64.new(new_high, new_low) end end @@ -1061,11 +1483,11 @@ function bit64.rshift(x, n) return bit64.new(0, 0) elseif n >= 32 then -- Shift by 32 or more: high becomes 0, low gets bits from high - return bit64.new(0, bit32.rshift(x[1], n - 32)) + return bit64.new(0, bit32_rshift(x[1], n - 32)) else -- Shift by less than 32 - local new_low = bit32.bor(bit32.rshift(x[2], n), bit32.lshift(x[1], 32 - n)) - local new_high = bit32.rshift(x[1], n) + local new_low = bit32_bor(bit32_rshift(x[2], n), bit32_lshift(x[1], 32 - n)) + local new_high = bit32_rshift(x[1], n) return bit64.new(new_high, new_low) end end @@ -1080,7 +1502,7 @@ function bit64.arshift(x, n) end -- Check sign bit (bit 31 of high word) - local is_negative = bit32.band(x[1], 0x80000000) ~= 0 + local is_negative = bit32_band(x[1], 0x80000000) ~= 0 if n >= 64 then -- All bits shift out, result is all 1s if negative, all 0s if positive @@ -1091,13 +1513,13 @@ function bit64.arshift(x, n) end elseif n >= 32 then -- High word shifts into low, high fills with sign - local new_low = bit32.arshift(x[1], n - 32) + local new_low = bit32_arshift(x[1], n - 32) local new_high = is_negative and 0xFFFFFFFF or 0 return bit64.new(new_high, new_low) else -- Shift by less than 32 - local new_low = bit32.bor(bit32.rshift(x[2], n), bit32.lshift(x[1], 32 - n)) - local new_high = bit32.arshift(x[1], n) + local new_low = bit32_bor(bit32_rshift(x[2], n), bit32_lshift(x[1], 32 - n)) + local new_high = bit32_arshift(x[1], n) return bit64.new(new_high, new_low) end end @@ -1123,14 +1545,14 @@ function bit64.rol(x, n) return bit64.new(low, high) elseif n < 32 then -- Rotate within 32-bit boundaries - local new_high = bit32.bor(bit32.lshift(high, n), bit32.rshift(low, 32 - n)) - local new_low = bit32.bor(bit32.lshift(low, n), bit32.rshift(high, 32 - n)) + local new_high = bit32_bor(bit32_lshift(high, n), bit32_rshift(low, 32 - n)) + local new_low = bit32_bor(bit32_lshift(low, n), bit32_rshift(high, 32 - n)) return bit64.new(new_high, new_low) else -- n > 32: rotate by (n - 32) after swapping n = n - 32 - local new_high = bit32.bor(bit32.lshift(low, n), bit32.rshift(high, 32 - n)) - local new_low = bit32.bor(bit32.lshift(high, n), bit32.rshift(low, 32 - n)) + local new_high = bit32_bor(bit32_lshift(low, n), bit32_rshift(high, 32 - n)) + local new_low = bit32_bor(bit32_lshift(high, n), bit32_rshift(low, 32 - n)) return bit64.new(new_high, new_low) end end @@ -1152,14 +1574,14 @@ function bit64.ror(x, n) return bit64.new(low, high) elseif n < 32 then -- Rotate within 32-bit boundaries - local new_low = bit32.bor(bit32.rshift(low, n), bit32.lshift(high, 32 - n)) - local new_high = bit32.bor(bit32.rshift(high, n), bit32.lshift(low, 32 - n)) + local new_low = bit32_bor(bit32_rshift(low, n), bit32_lshift(high, 32 - n)) + local new_high = bit32_bor(bit32_rshift(high, n), bit32_lshift(low, 32 - n)) return bit64.new(new_high, new_low) else -- n > 32: rotate by (n - 32) after swapping n = n - 32 - local new_low = bit32.bor(bit32.rshift(high, n), bit32.lshift(low, 32 - n)) - local new_high = bit32.bor(bit32.rshift(low, n), bit32.lshift(high, 32 - n)) + local new_low = bit32_bor(bit32_rshift(high, n), bit32_lshift(low, 32 - n)) + local new_high = bit32_bor(bit32_rshift(low, n), bit32_lshift(high, 32 - n)) return bit64.new(new_high, new_low) end end @@ -1196,14 +1618,14 @@ end --- @param x Int64HighLow 64-bit value {high, low} --- @return string bytes 8-byte string in big-endian order function bit64.u64_to_be_bytes(x) - return bit32.u32_to_be_bytes(x[1]) .. bit32.u32_to_be_bytes(x[2]) + return bit32_u32_to_be_bytes(x[1]) .. bit32_u32_to_be_bytes(x[2]) end --- Convert 64-bit value to 8 bytes (little-endian). --- @param x Int64HighLow 64-bit value {high, low} --- @return string bytes 8-byte string in little-endian order function bit64.u64_to_le_bytes(x) - return bit32.u32_to_le_bytes(x[2]) .. bit32.u32_to_le_bytes(x[1]) + return bit32_u32_to_le_bytes(x[2]) .. bit32_u32_to_le_bytes(x[1]) end --- Convert 8 bytes to 64-bit value (big-endian). @@ -1213,8 +1635,8 @@ end function bit64.be_bytes_to_u64(str, offset) offset = offset or 1 assert(#str >= offset + 7, "Insufficient bytes for u64") - local high = bit32.be_bytes_to_u32(str, offset) - local low = bit32.be_bytes_to_u32(str, offset + 4) + local high = bit32_be_bytes_to_u32(str, offset) + local low = bit32_be_bytes_to_u32(str, offset + 4) return bit64.new(high, low) end @@ -1225,8 +1647,8 @@ end function bit64.le_bytes_to_u64(str, offset) offset = offset or 1 assert(#str >= offset + 7, "Insufficient bytes for u64") - local low = bit32.le_bytes_to_u32(str, offset) - local high = bit32.le_bytes_to_u32(str, offset + 4) + local low = bit32_le_bytes_to_u32(str, offset) + local high = bit32_le_bytes_to_u32(str, offset + 4) return bit64.new(high, low) end @@ -1245,9 +1667,15 @@ end --- Warning: Lua numbers use 64-bit IEEE 754 doubles with 53-bit mantissa precision. --- Values exceeding 53 bits (greater than 9007199254740991) will lose precision. --- To maintain full 64-bit precision, keep values in {high, low} format. ---- @param value number|Int64HighLow The {high_32, low_32} pair (or number to pass through). +--- @param value number|integer|Int64HighLow The {high_32, low_32} pair (or number to pass through). --- @param strict? boolean If true, errors when value exceeds 53-bit precision. ---- @return number result The value as a Lua number (may lose precision for large values unless strict). +--- @return number|integer result The value as a Lua number (may lose precision for large values unless strict). +--- @overload fun(value: number, strict?: boolean): number +--- @overload fun(value: integer, strict?: boolean): integer +--- @overload fun(value: Int64HighLow, strict?: boolean): integer +--- @overload fun(value: number): number +--- @overload fun(value: integer): integer +--- @overload fun(value: Int64HighLow): integer function bit64.to_number(value, strict) if type(value) == "number" then return value @@ -1335,6 +1763,7 @@ end --- @return boolean result True if all tests pass, false otherwise function bit64.selftest() print("Running 64-bit operations test vectors...") + print(string.format(" Using: %s", impl_name())) local passed = 0 local total = 0 @@ -1884,16 +2313,192 @@ function bit64.selftest() return passed == total end +-------------------------------------------------------------------------------- +-- Benchmarking +-------------------------------------------------------------------------------- + +local benchmark_op = require("bitn.utils.benchmark").benchmark_op + +--- Run performance benchmarks for 64-bit operations. +function bit64.benchmark() + local iterations = 100000 + + print("64-bit Bitwise Operations:") + print(string.format(" Implementation: %s", impl_name())) + + -- Test values + local a = bit64.new(0xAAAAAAAA, 0x55555555) + local b = bit64.new(0x55555555, 0xAAAAAAAA) + + benchmark_op("band", function() + bit64.band(a, b) + end, iterations) + + benchmark_op("bor", function() + bit64.bor(a, b) + end, iterations) + + benchmark_op("bxor", function() + bit64.bxor(a, b) + end, iterations) + + benchmark_op("bnot", function() + bit64.bnot(a) + end, iterations) + + print("\n64-bit Shift Operations:") + + benchmark_op("lshift (small)", function() + bit64.lshift(a, 8) + end, iterations) + + benchmark_op("lshift (large)", function() + bit64.lshift(a, 40) + end, iterations) + + benchmark_op("rshift (small)", function() + bit64.rshift(a, 8) + end, iterations) + + benchmark_op("rshift (large)", function() + bit64.rshift(a, 40) + end, iterations) + + benchmark_op("arshift", function() + bit64.arshift(bit64.new(0x80000000, 0), 8) + end, iterations) + + print("\n64-bit Rotate Operations:") + + benchmark_op("rol (small)", function() + bit64.rol(a, 8) + end, iterations) + + benchmark_op("rol (large)", function() + bit64.rol(a, 40) + end, iterations) + + benchmark_op("ror (small)", function() + bit64.ror(a, 8) + end, iterations) + + benchmark_op("ror (large)", function() + bit64.ror(a, 40) + end, iterations) + + print("\n64-bit Arithmetic:") + + benchmark_op("add", function() + bit64.add(a, b) + end, iterations) + + benchmark_op("add (with carry)", function() + bit64.add(bit64.new(0, 0xFFFFFFFF), bit64.new(0, 1)) + end, iterations) + + print("\n64-bit Byte Conversions:") + + local val = bit64.new(0x12345678, 0x9ABCDEF0) + local bytes_be = bit64.u64_to_be_bytes(val) + local bytes_le = bit64.u64_to_le_bytes(val) + + benchmark_op("u64_to_be_bytes", function() + bit64.u64_to_be_bytes(val) + end, iterations) + + benchmark_op("u64_to_le_bytes", function() + bit64.u64_to_le_bytes(val) + end, iterations) + + benchmark_op("be_bytes_to_u64", function() + bit64.be_bytes_to_u64(bytes_be) + end, iterations) + + benchmark_op("le_bytes_to_u64", function() + bit64.le_bytes_to_u64(bytes_le) + end, iterations) + + print("\n64-bit Utility Functions:") + + benchmark_op("new", function() + bit64.new(0x12345678, 0x9ABCDEF0) + end, iterations) + + benchmark_op("is_int64", function() + bit64.is_int64(a) + end, iterations) + + benchmark_op("to_hex", function() + bit64.to_hex(a) + end, iterations) + + benchmark_op("to_number", function() + bit64.to_number(a) + end, iterations) + + benchmark_op("from_number", function() + bit64.from_number(12345678901234) + end, iterations) + + benchmark_op("eq", function() + bit64.eq(a, b) + end, iterations) + + benchmark_op("is_zero", function() + bit64.is_zero(a) + end, iterations) +end + return bit64 end end +do +local _ENV = _ENV +package.preload[ "bitn.utils.benchmark" ] = function( ... ) local arg = _G.arg; +--- @module "bitn.utils.benchmark" +--- Common benchmarking utilities for performance testing +--- @class bitn.utils.benchmark +local benchmark = {} + +--- Run a benchmarked operation with warmup and timing +--- @param name string Operation name for display +--- @param func function Function to benchmark +--- @param iterations? integer Number of iterations (default: 100) +--- @return number ms_per_op Milliseconds per operation +function benchmark.benchmark_op(name, func, iterations) + iterations = iterations or 100 + + -- Warmup + for _ = 1, 3 do + func() + end + + -- Actual benchmark + local start = os.clock() + for _ = 1, iterations do + func() + end + local elapsed = os.clock() - start + + local per_op = (elapsed / iterations) * 1000 -- ms + local ops_per_sec = iterations / elapsed + + print(string.format("%-30s: %8.3f ms/op, %8.1f ops/sec", name, per_op, ops_per_sec)) + + return per_op +end + +return benchmark +end +end + --- @module "bitn" ---- Pure Lua bitwise operations library. +--- Portable bitwise operations library with automatic optimization. --- This library provides standalone, version-agnostic implementations of --- bitwise operations for 16-bit, 32-bit, and 64-bit integers. It works ---- across Lua 5.1, 5.2, 5.3, 5.4, and LuaJIT without depending on any ---- built-in bit libraries. +--- across Lua 5.1, 5.2, 5.3, 5.4, and LuaJIT with zero external dependencies. +--- Automatically uses native bit operations when available for optimal performance. --- --- @usage --- local bitn = require("bitn") @@ -1910,13 +2515,16 @@ end --- --- @class bitn local bitn = { + --- @type bitn.bit16 16-bit bitwise operations bit16 = require("bitn.bit16"), + --- @type bitn.bit32 32-bit bitwise operations bit32 = require("bitn.bit32"), + --- @type bitn.bit64 64-bit bitwise operations bit64 = require("bitn.bit64"), } --- Library version (injected at build time for releases). -local VERSION = "v0.4.1" +local VERSION = "v0.5.1" --- Get the library version string. --- @return string version Version string (e.g., "v1.0.0" or "dev")