Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
151 changes: 73 additions & 78 deletions src/noiseprotocol/crypto/aes_gcm.lua
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,12 @@ local utils = require("noiseprotocol.utils")
local bytes = utils.bytes
local benchmark_op = utils.benchmark.benchmark_op

-- Local references for performance (avoid module table lookups in hot loops)
local bit32_band = bit32.band
local bit32_bor = bit32.bor
local bit32_bxor = bit32.bxor
local bit32_lshift = bit32.lshift
local bit32_rshift = bit32.rshift
-- Local references for performance
local bit32_raw_band = bit32.raw_band
local bit32_raw_bor = bit32.raw_bor
local bit32_raw_bxor = bit32.raw_bxor
local bit32_raw_lshift = bit32.raw_lshift
local bit32_raw_rshift = bit32.raw_rshift
local string_byte = string.byte
local string_char = string.char
local string_rep = string.rep
Expand Down Expand Up @@ -302,50 +302,29 @@ local RCON = {
0x36,
}

--- @alias AESGCMWord [integer, integer, integer, integer]
--- @alias AESGCMBlock [integer, integer, integer, integer, integer, integer, integer, integer, integer, integer, integer, integer, integer, integer, integer, integer]
--- @alias AESGCMState [AESGCMWord, AESGCMWord, AESGCMWord, AESGCMWord]
--- @alias AESWord [integer, integer, integer, integer]
--- @alias AESBlock [integer, integer, integer, integer, integer, integer, integer, integer, integer, integer, integer, integer, integer, integer, integer, integer]
--- @alias AESState [AESWord, AESWord, AESWord, AESWord]

--- Initialize a 16-element GCM block with zeros
--- @return AESGCMBlock block Initialized block
local function create_gcm_block()
local arr = {}
for i = 1, 16 do
arr[i] = 0
end
--- @cast arr AESGCMBlock
return arr
--- Initialize a 4-element AES word with zeros
--- @return AESWord word Initialized word
local function create_aes_word()
--- @type AESState
return { 0, 0, 0, 0 }
end

--- Initialize a 4x4 AES state array with zeros
--- @return AESGCMState state Initialized state
--- @return AESState state Initialized state
local function create_aes_state()
local state = {}
for i = 1, 4 do
state[i] = {}
for j = 1, 4 do
state[i][j] = 0
end
end
--- @cast state AESGCMState
return state
end

--- Initialize a 4-element AES word with zeros
--- @return AESGCMWord word Initialized word
local function create_aes_word()
local arr = {}
for i = 1, 4 do
arr[i] = 0
end
--- @cast arr AESGCMWord
return arr
--- @type AESState
return {
create_aes_word(),
create_aes_word(),
create_aes_word(),
create_aes_word(),
}
end

-- Pre-allocated arrays for gcm_multiply() to avoid repeated allocation
local gcm_z = create_gcm_block()
local gcm_v = create_gcm_block()

-- Pre-allocated state array for aes_encrypt_block()
local aes_state = create_aes_state()

Expand All @@ -354,28 +333,28 @@ local mix_a = create_aes_word()
local mix_b = create_aes_word()

--- XOR two 4-byte words
--- @param a AESGCMWord 4-byte array
--- @param b AESGCMWord 4-byte array
--- @return table Word 4-byte array
--- @param a AESWord 4-byte array
--- @param b AESWord 4-byte array
--- @return AESWord result 4-byte array
local function xor_words(a, b)
return {
bit32_bxor(a[1], b[1]),
bit32_bxor(a[2], b[2]),
bit32_bxor(a[3], b[3]),
bit32_bxor(a[4], b[4]),
bit32_raw_bxor(a[1], b[1]),
bit32_raw_bxor(a[2], b[2]),
bit32_raw_bxor(a[3], b[3]),
bit32_raw_bxor(a[4], b[4]),
}
end

--- Rotate word (circular left shift by 1 byte)
--- @param word AESGCMWord 4-byte array
--- @return AESGCMWord result Rotated 4-byte array
--- @param word AESWord 4-byte array
--- @return AESWord result Rotated 4-byte array
local function rot_word(word)
return { word[2], word[3], word[4], word[1] }
end

--- Apply S-box substitution to a word
--- @param word AESGCMWord 4-byte array
--- @return AESGCMWord result Substituted 4-byte array
--- @param word AESWord 4-byte array
--- @return AESWord result Substituted 4-byte array
local function sub_word(word)
local s_1 = assert(SBOX[word[1] + 1], "Invalid SBOX index " .. (word[1] + 1))
local s_2 = assert(SBOX[word[2] + 1], "Invalid SBOX index " .. (word[2] + 1))
Expand Down Expand Up @@ -407,7 +386,7 @@ local function key_expansion(key)
end

-- Convert key to words
--- @type AESGCMState
--- @type AESState
local w = {}
for i = 1, nk do
w[i] = {
Expand Down Expand Up @@ -435,27 +414,28 @@ local function key_expansion(key)
end

--- MixColumns transformation
--- @param state AESGCMState 4x4 state matrix
--- @param state AESState 4x4 state matrix
local function mix_columns(state)
-- Reuse pre-allocated arrays
local a = mix_a
local b = mix_b
for c = 1, 4 do
for i = 1, 4 do
a[i] = state[i][c]
b[i] = bit32_band(state[i][c], 0x80) ~= 0 and bit32_bxor(bit32_band(bit32_lshift(state[i][c], 1), 0xFF), 0x1B)
or bit32_band(bit32_lshift(state[i][c], 1), 0xFF)
b[i] = bit32_raw_band(state[i][c], 0x80) ~= 0
and bit32_raw_bxor(bit32_raw_band(bit32_raw_lshift(state[i][c], 1), 0xFF), 0x1B)
or bit32_raw_band(bit32_raw_lshift(state[i][c], 1), 0xFF)
end

state[1][c] = bit32_bxor(bit32_bxor(bit32_bxor(b[1], a[2]), bit32_bxor(b[2], a[3])), a[4])
state[2][c] = bit32_bxor(bit32_bxor(bit32_bxor(a[1], b[2]), bit32_bxor(a[3], b[3])), a[4])
state[3][c] = bit32_bxor(bit32_bxor(bit32_bxor(a[1], a[2]), bit32_bxor(b[3], a[4])), b[4])
state[4][c] = bit32_bxor(bit32_bxor(bit32_bxor(a[1], b[1]), bit32_bxor(a[2], a[3])), b[4])
state[1][c] = bit32_raw_bxor(bit32_raw_bxor(bit32_raw_bxor(b[1], a[2]), bit32_raw_bxor(b[2], a[3])), a[4])
state[2][c] = bit32_raw_bxor(bit32_raw_bxor(bit32_raw_bxor(a[1], b[2]), bit32_raw_bxor(a[3], b[3])), a[4])
state[3][c] = bit32_raw_bxor(bit32_raw_bxor(bit32_raw_bxor(a[1], a[2]), bit32_raw_bxor(b[3], a[4])), b[4])
state[4][c] = bit32_raw_bxor(bit32_raw_bxor(bit32_raw_bxor(a[1], b[1]), bit32_raw_bxor(a[2], a[3])), b[4])
end
end

--- SubBytes transformation
--- @param state AESGCMState 4x4 state matrix
--- @param state AESState 4x4 state matrix
local function sub_bytes(state)
for i = 1, 4 do
for j = 1, 4 do
Expand All @@ -466,7 +446,7 @@ local function sub_bytes(state)
end

--- ShiftRows transformation
--- @param state AESGCMState 4x4 state matrix
--- @param state AESState 4x4 state matrix
local function shift_rows(state)
-- Row 1: no shift
-- Row 2: shift left by 1
Expand All @@ -493,14 +473,14 @@ local function shift_rows(state)
end

--- AddRoundKey transformation
--- @param state AESGCMState 4x4 state matrix
--- @param state AESState 4x4 state matrix
--- @param round_key table Round key words
--- @param round integer Round number
local function add_round_key(state, round_key, round)
for c = 1, 4 do
local key_word = round_key[round * 4 + c]
for r = 1, 4 do
state[r][c] = bit32_bxor(state[r][c], key_word[r])
state[r][c] = bit32_raw_bxor(state[r][c], key_word[r])
end
end
end
Expand Down Expand Up @@ -552,6 +532,21 @@ end
-- GCM MODE IMPLEMENTATION
-- ============================================================================

--- Initialize a 16-element GCM block with zeros
--- @return AESBlock block Initialized block
local function create_gcm_block()
local arr = {}
for i = 1, 16 do
arr[i] = 0
end
--- @cast arr AESBlock
return arr
end

-- Pre-allocated arrays for gcm_multiply() to avoid repeated allocation
local gcm_z = create_gcm_block()
local gcm_v = create_gcm_block()

--- GCM field multiplication
--- @param x string 16-byte block
--- @param y string 16-byte block
Expand All @@ -570,27 +565,27 @@ local function gcm_multiply(x, y)
for i = 1, 16 do
local byte = string_byte(x, i)
for bit = 7, 0, -1 do
if bit32_band(byte, bit32_lshift(1, bit)) ~= 0 then
if bit32_raw_band(byte, bit32_raw_lshift(1, bit)) ~= 0 then
-- z = z XOR v
for j = 1, 16 do
z[j] = bit32_bxor(z[j], v[j])
z[j] = bit32_raw_bxor(z[j], v[j])
end
end

-- Check if LSB of v is 1 (bit 0 of last byte)
local lsb = bit32_band(v[16], 1)
local lsb = bit32_raw_band(v[16], 1)

-- v = v >> 1 (right shift entire 128-bit value by 1)
local carry = 0
for j = 1, 16 do
local new_carry = bit32_band(v[j], 1)
v[j] = bit32_bor(bit32_rshift(v[j], 1), bit32_lshift(carry, 7))
local new_carry = bit32_raw_band(v[j], 1)
v[j] = bit32_raw_bor(bit32_raw_rshift(v[j], 1), bit32_raw_lshift(carry, 7))
carry = new_carry
end

-- If LSB was 1, XOR with R = 0xE1000000000000000000000000000000
if lsb ~= 0 then
v[1] = bit32_bxor(v[1], 0xE1)
v[1] = bit32_raw_bxor(v[1], 0xE1)
end
end
end
Expand All @@ -617,7 +612,7 @@ local function ghash(h, data)
-- y = (y XOR block) * h
local y_xor = ""
for j = 1, 16 do
y_xor = y_xor .. string_char(bit32_bxor(string_byte(y, j), string_byte(block, j)))
y_xor = y_xor .. string_char(bit32_raw_bxor(string_byte(y, j), string_byte(block, j)))
end

y = gcm_multiply(y_xor, h)
Expand All @@ -642,7 +637,7 @@ local function inc_counter(counter)

-- Convert back to bytes (big-endian)
for i = 3, 0, -1 do
result = result .. string_char(bit32_band(bit32_rshift(val, i * 8), 0xFF))
result = result .. string_char(bit32_raw_band(bit32_raw_rshift(val, i * 8), 0xFF))
end

return result
Expand Down Expand Up @@ -762,7 +757,7 @@ function aes_gcm.encrypt(key, nonce, plaintext, aad)
local keystream = generate_keystream(key, nonce, #plaintext)
local ciphertext = ""
for i = 1, #plaintext do
ciphertext = ciphertext .. string_char(bit32_bxor(string_byte(plaintext, i), string_byte(keystream, i)))
ciphertext = ciphertext .. string_char(bit32_raw_bxor(string_byte(plaintext, i), string_byte(keystream, i)))
end

-- Calculate authentication tag
Expand All @@ -773,7 +768,7 @@ function aes_gcm.encrypt(key, nonce, plaintext, aad)
local encrypted_j0 = aes_encrypt_block(j0, expanded_key, nr)
local tag = ""
for i = 1, 16 do
tag = tag .. string_char(bit32_bxor(string_byte(s, i), string_byte(encrypted_j0, i)))
tag = tag .. string_char(bit32_raw_bxor(string_byte(s, i), string_byte(encrypted_j0, i)))
end

return ciphertext .. tag
Expand Down Expand Up @@ -841,7 +836,7 @@ function aes_gcm.decrypt(key, nonce, ciphertext_and_tag, aad)
local encrypted_j0 = aes_encrypt_block(j0, expanded_key, nr)
local expected_tag = ""
for i = 1, 16 do
expected_tag = expected_tag .. string_char(bit32_bxor(string_byte(s, i), string_byte(encrypted_j0, i)))
expected_tag = expected_tag .. string_char(bit32_raw_bxor(string_byte(s, i), string_byte(encrypted_j0, i)))
end

-- Verify tag (constant-time comparison)
Expand All @@ -853,7 +848,7 @@ function aes_gcm.decrypt(key, nonce, ciphertext_and_tag, aad)
local keystream = generate_keystream(key, nonce, #ciphertext)
local plaintext = ""
for i = 1, #ciphertext do
plaintext = plaintext .. string_char(bit32_bxor(string_byte(ciphertext, i), string_byte(keystream, i)))
plaintext = plaintext .. string_char(bit32_raw_bxor(string_byte(ciphertext, i), string_byte(keystream, i)))
end

return plaintext
Expand Down
Loading
Loading