-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathjwt_parser.lua
More file actions
279 lines (240 loc) · 8.03 KB
/
jwt_parser.lua
File metadata and controls
279 lines (240 loc) · 8.03 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
-- JWT verification module
-- Adapted version of x25/luajwt for Kong. It provides various improvements and
-- an OOP architecture allowing the JWT to be parsed and verified separately,
-- avoiding multiple parsings.
--
-- @see https://github.com/x25/luajwt
local json = require "cjson"
local utils = require "kong.tools.utils"
local crypto = require "crypto"
local asn_sequence = require "kong.plugins.jwt.asn_sequence"
local error = error
local type = type
local pcall = pcall
local ngx_time = ngx.time
local string_rep = string.rep
local string_sub = string.sub
local table_concat = table.concat
local setmetatable = setmetatable
local encode_base64 = ngx.encode_base64
local decode_base64 = ngx.decode_base64
--- Supported algorithms for signing tokens.
-- Maps a JWS "alg" name to a signer(data, key) returning the raw binary
-- signature of `data`. HS256 uses `key` as the HMAC secret; RS256, RS512 and
-- ES256 parse `key` with crypto.pkey.from_pem(key, true) (the `true` flag
-- presumably selects a private key — confirm against luacrypto).
local alg_sign
do
  -- Build an RSA signer bound to the given digest name.
  local function rsa_signer(digest)
    return function(data, key)
      return crypto.sign(digest, data, crypto.pkey.from_pem(key, true))
    end
  end

  alg_sign = {
    HS256 = function(data, key)
      return crypto.hmac.digest("sha256", data, key, true)
    end,
    -- Not enabled:
    --HS384 = function(data, key) return crypto.hmac.digest("sha384", data, key, true) end,
    --HS512 = function(data, key) return crypto.hmac.digest("sha512", data, key, true) end,
    RS256 = rsa_signer("sha256"),
    RS512 = rsa_signer("sha512"),
    ES256 = function(data, key)
      -- crypto.sign yields a sequence holding the two integers (R, S); the
      -- token carries them as two 32-byte values concatenated.
      local der = crypto.sign("sha256", data, crypto.pkey.from_pem(key, true))
      local seq = asn_sequence.parse_simple_sequence(der)
      local r = asn_sequence.unsign_integer(seq[1], 32)
      local s = asn_sequence.unsign_integer(seq[2], 32)
      assert(#r == 32)
      assert(#s == 32)
      return r .. s
    end,
  }
end
--- Supported algorithms for verifying tokens.
-- Maps a JWS "alg" name to a verifier(data, signature, key), where `data` is
-- the signing input and `signature` the base64url-decoded signature taken
-- from the (untrusted) token. HS256 recomputes the HMAC via alg_sign; RS256,
-- RS512 and ES256 parse `key` with crypto.pkey.from_pem and raise if it is
-- unusable (a configuration error, not a token error).
local alg_verify = {
  ["HS256"] = function(data, signature, key) return signature == alg_sign["HS256"](data, key) end,
  --["HS384"] = function(data, signature, key) return signature == alg_sign["HS384"](data, key) end,
  --["HS512"] = function(data, signature, key) return signature == alg_sign["HS512"](data, key) end,
  ["RS256"] = function(data, signature, key)
    local pkey = assert(crypto.pkey.from_pem(key), "Consumer Public Key is Invalid")
    return crypto.verify('sha256', data, signature, pkey)
  end,
  ["RS512"] = function(data, signature, key)
    local pkey = assert(crypto.pkey.from_pem(key), "Consumer Public Key is Invalid")
    return crypto.verify('sha512', data, signature, pkey)
  end,
  ["ES256"] = function(data, signature, key)
    local pkey = assert(crypto.pkey.from_pem(key), "Consumer Public Key is Invalid")
    -- An ES256 JWS signature is R and S as two 32-byte values concatenated.
    -- The signature comes from the token, so a wrong length means an invalid
    -- signature: report it as such instead of raising a Lua error.
    if type(signature) ~= "string" or #signature ~= 64 then
      return false
    end
    -- Re-encode (R, S) as the sequence form crypto.sign produces (the
    -- inverse of what alg_sign.ES256 does) before handing it to crypto.verify.
    local r = asn_sequence.resign_integer(string_sub(signature, 1, 32))
    local s = asn_sequence.resign_integer(string_sub(signature, 33, 64))
    local signature_asn = asn_sequence.create_simple_sequence({ r, s })
    return crypto.verify('sha256', data, signature_asn, pkey)
  end
}
--- Base64url encoding (RFC 4648 section 5) without padding, as used by JWS.
-- @param input String to encode
-- @return Base64url-encoded string with the `=` padding stripped
local function b64_encode(input)
  local result = encode_base64(input)
  -- Translate the standard alphabet to the URL-safe one in a single pass.
  -- Using a character class keeps `+` from being read as a pattern
  -- quantifier; mapping `=` to "" drops the padding.
  result = result:gsub("[+/=]", { ["+"] = "-", ["/"] = "_", ["="] = "" })
  return result
end
--- Base64url decoding that tolerates stripped padding.
-- Restores `=` padding to a multiple of 4 characters, maps the URL-safe
-- alphabet back to the standard one, then decodes.
-- @param input Base64url string to decode (must be a string)
-- @return Decoded string, or whatever decode_base64 returns for invalid input
local function b64_decode(input)
  local remainder = #input % 4
  if remainder > 0 then
    input = input .. string_rep('=', 4 - remainder)
  end
  -- Single pass; `-` is escaped because it is the lazy quantifier in Lua
  -- patterns.
  input = input:gsub("[%-_]", { ["-"] = "+", ["_"] = "/" })
  return decode_base64(input)
end
--- Tokenize a string by delimiter
-- Used to separate the header, claims and signature parts of a JWT.
-- The string is split on at most the first `len - 1` occurrences of `div`;
-- the remainder (which may still contain `div`) becomes the last part, so at
-- most `len` parts are returned.
-- @param str String to tokenize
-- @param div Delimiter, matched as a plain string (not a pattern)
-- @param len Maximum number of parts to retrieve (a number)
-- @return A table of strings (always at least one element)
local function tokenize(str, div, len)
  local result, pos = {}, 1
  -- Checking the bound before each split makes len == 1 return the whole
  -- string as a single part.
  while #result < len - 1 do
    local st, sp = str:find(div, pos, true)
    if not st then
      break
    end
    result[#result + 1] = str:sub(pos, st - 1)
    pos = sp + 1
  end
  result[#result + 1] = str:sub(pos)
  return result
end
--- Parse a JWT
-- Split a compact-serialized JWT into its three dot-separated parts, decode
-- them and validate the header values. The token is untrusted input, so
-- every malformed shape is reported as an error instead of raising.
-- @param token JWT to parse (a string)
-- @return A table containing the base64 and decoded header, claims and
--   signature; nil on failure
-- @return An error string on failure
local function decode_token(token)
  -- Get b64 parts. Missing parts stay nil and make decoding below fail
  -- inside the pcall.
  local parts = tokenize(token, ".", 3)
  local header_64, claims_64, signature_64 = parts[1], parts[2], parts[3]

  -- Decode JSON
  local ok, header, claims, signature = pcall(function()
    return json.decode(b64_decode(header_64)),
           json.decode(b64_decode(claims_64)),
           b64_decode(signature_64)
  end)
  if not ok then
    return nil, "invalid JSON"
  end

  -- The header must be a JSON object before it can be indexed. Indexing a
  -- decoded number, boolean or JSON null would raise.
  if type(header) ~= "table" then
    return nil, "invalid JSON"
  end

  -- typ is optional, but when present it must be the string "JWT"
  -- (case-insensitive). A non-string typ would make :upper() raise.
  local typ = header.typ
  if typ and (type(typ) ~= "string" or typ:upper() ~= "JWT") then
    return nil, "invalid typ"
  end

  if type(header.alg) ~= "string" or not alg_verify[header.alg] then
    return nil, "invalid alg"
  end

  -- The claims set must be a JSON object (RFC 7519).
  if type(claims) ~= "table" then
    return nil, "invalid claims"
  end

  if not signature then
    return nil, "invalid signature"
  end

  return {
    token = token,
    header_64 = header_64,
    claims_64 = claims_64,
    signature_64 = signature_64,
    header = header,
    claims = claims,
    signature = signature
  }
end
--- Encode and sign a JWT (for test purposes).
-- @param data Claims table to encode as the payload
-- @param key Signing key (string) passed to the selected alg_sign signer
-- @param alg Signing algorithm name; defaults to "HS256" and must be a key
--   of alg_sign
-- @param header Optional header table; defaults to {typ = "JWT", alg = alg}.
--   A custom header is encoded as given, so its `alg` may differ from the
--   algorithm actually used to sign (useful for negative tests).
-- @return The token "<header_b64>.<claims_b64>.<signature_b64>"
-- @raise If an argument has the wrong type or `alg` is unsupported
local function encode_token(data, key, alg, header)
  if type(data) ~= "table" then
    error("Argument #1 must be table", 2)
  end
  if type(key) ~= "string" then
    error("Argument #2 must be string", 2)
  end
  if header and type(header) ~= "table" then
    error("Argument #4 must be a table", 2)
  end

  alg = alg or "HS256"
  local signer = alg_sign[alg]
  if not signer then
    error("Algorithm not supported", 2)
  end

  -- Reassign the parameter instead of redeclaring a shadowing local.
  header = header or { typ = "JWT", alg = alg }

  local segments = {
    b64_encode(json.encode(header)),
    b64_encode(json.encode(data)),
  }
  local signing_input = table_concat(segments, ".")
  segments[#segments + 1] = b64_encode(signer(signing_input, key))
  return table_concat(segments, ".")
end
--[[
JWT public interface
]]
local _M = {}
_M.__index = _M

--- Instantiate a JWT parser
-- Parses a JWT once and returns a parser object (with the module as its
-- metatable) for further operations such as signature and claims checks.
-- @param token JWT to parse; must be a string, otherwise an error is raised
-- @return JWT parser, or nil if the token could not be decoded
-- @return Error string from decoding, if any
function _M:new(token)
  if type(token) == "string" then
    local decoded, decode_err = decode_token(token)
    if decode_err then
      return nil, decode_err
    end
    return setmetatable(decoded, _M)
  end
  error("Token must be a string, got " .. tostring(token), 2)
end
--- Verify a JWT signature
-- Runs the verifier registered for the token's header `alg` over the
-- signing input `header_64 .. "." .. claims_64` and the decoded signature.
-- @param key Key to verify against (HMAC secret for HS256, PEM public key
--   for RS256/RS512/ES256)
-- @return The verifier's result, truthy when the signature verifies
function _M:verify_signature(key)
  local verifier = alg_verify[self.header.alg]
  local signing_input = self.header_64 .. "." .. self.claims_64
  return verifier(signing_input, self.signature, key)
end
--- Decode a base64url string, restoring any stripped padding.
-- Method-style wrapper around the module-local b64_decode helper; the
-- receiver is ignored.
-- @param input String to decode
-- @return Whatever the b64_decode helper returns for `input`
function _M.b64_decode(_, input)
  return b64_decode(input)
end
--- Registered claims according to RFC 7519 Section 4.1
-- Maps a claim name to the Lua type its value must have and a check that
-- returns an error message when the claim fails (nothing when it passes).
-- Both checks compare against ngx_time().
local registered_claims
do
  -- "nbf": the token must not be used before this time.
  local function check_not_before(nbf)
    if ngx_time() < nbf then
      return "token not valid yet"
    end
  end

  -- "exp": the token is rejected at or after this time.
  local function check_expiration(exp)
    if ngx_time() >= exp then
      return "token expired"
    end
  end

  registered_claims = {
    nbf = { type = "number", check = check_not_before },
    exp = { type = "number", check = check_expiration },
  }
end
--- Verify registered claims (according to RFC 7519 Section 4.1)
-- Each named claim is first checked for its expected Lua type, then against
-- its rule in `registered_claims`.
-- @param claims_to_verify A list of claim names to verify (nil means none).
--   Every name must have an entry in `registered_claims`.
-- @return A boolean, true if no errors were found
-- @return The errors accumulated with utils.add_error, or nil if none
-- @raise If a claim name has no verification rule
function _M:verify_registered_claims(claims_to_verify)
  if not claims_to_verify then
    claims_to_verify = {}
  end

  local errors = nil
  for _, claim_name in pairs(claims_to_verify) do
    local claim_rules = registered_claims[claim_name]
    -- An unknown claim name is a caller/configuration error, not a token
    -- error. Fail with a clear message instead of indexing nil.
    if not claim_rules then
      error("unsupported claim to verify: " .. tostring(claim_name), 2)
    end

    local claim = self.claims[claim_name]
    if type(claim) ~= claim_rules.type then
      errors = utils.add_error(errors, claim_name, "must be a " .. claim_rules.type)
    else
      local check_err = claim_rules.check(claim)
      if check_err then
        errors = utils.add_error(errors, claim_name, check_err)
      end
    end
  end

  return errors == nil, errors
end
-- Expose encode_token on the module as `encode`; it is intended for tests.
_M.encode = encode_token
return _M