-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathgrpc_client.lua
More file actions
287 lines (234 loc) · 8.54 KB
/
grpc_client.lua
File metadata and controls
287 lines (234 loc) · 8.54 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
-- Support for Lua 5.2 and below
require("compat53")
local socket = require("cqueues.socket")
local headers = require("http.headers")
local h2_connection = require("http.h2_connection")
local pb = require("pb")
-- Public module table; it is what this file returns at the end.
local grpc_client = {}
---Method table shared by every client created with `grpc_client.new`
---(it is installed as the client's `__index`).
---@class grpc_client.Client
---@field conn grpc_client.h2.Conn The HTTP/2 connection every request opens a new stream on
---@field loop grpc_client.cqueues.Loop Controller that the streaming requests schedule their response readers on
local Client = {}
---Create a new gRPC client that connects to the socket specified with `sock_args`.
---See `socket.connect` in the cqueues manual for more information.
---
---Raises an error (reported at the caller) when the socket cannot be created,
---the TCP connection fails, or the HTTP/2 handshake fails, instead of handing
---back a client whose connection is unusable. The socket is closed before the
---error is raised.
---
---@nodiscard
---@param sock_args table A table of named arguments from `cqueues.socket.connect`
---@return grpc_client.Client
function grpc_client.new(sock_args)
	local sock, sock_err = socket.connect(sock_args)
	if not sock then
		error("failed to create socket: " .. tostring(sock_err), 2)
	end

	-- Both `connect` calls signal failure through their return values
	-- (nil/false plus an error), so the results must be checked explicitly.
	local ok, err = sock:connect()
	if not ok then
		sock:close()
		error("failed to connect socket: " .. tostring(err), 2)
	end

	local conn = h2_connection.new(sock, "client")
	ok, err = conn:connect()
	if not ok then
		-- Closing the connection also releases the socket it wraps.
		conn:close()
		error("failed to establish HTTP/2 connection: " .. tostring(err), 2)
	end

	---@type grpc_client.Client
	local client = setmetatable({
		conn = conn,
		loop = require("cqueues").new(),
	}, { __index = Client })
	return client
end
---Encodes the given `data` as the protobuf type `type_name` and wraps it in a
---gRPC length-prefixed message frame.
---
---Raises an error if the protobuf encoder rejects the data.
---
---@param type_name string The absolute protobuf type
---@param data table The table of data, conforming to its protobuf definition
---@return string bytes The framed message: a 1-byte compressed flag (always 0),
---a 4-byte big-endian payload length, then the encoded protobuf bytes
local function encode(type_name, data)
	local ok, result = pcall(pb.encode, type_name, data)
	if not ok then
		-- tostring: the caught error value is not guaranteed to be a string.
		error("failed to encode `" .. tostring(type_name) .. "`: " .. tostring(result))
	end
	-- Compressed-Flag = 0 because this client never compresses messages,
	-- followed by the payload length as a big-endian uint32.
	return string.pack(">I1I4", 0, #result) .. result
end
---Builds the HTTP/2 request header block for a gRPC call to
---`/<service>/<method>`.
---
---The fields are appended in a fixed order: pseudo-headers first, then `te`
---and `content-type`.
---
---@param service string The desired service
---@param method string The desired method within the service
---@return table req_headers A fresh `http.headers` object
local function create_request_headers(service, method)
	local req = headers.new()
	local fields = {
		{ ":method", "POST" },
		{ ":scheme", "http" },
		{ ":path", "/" .. service .. "/" .. method },
		{ "te", "trailers" },
		{ "content-type", "application/grpc" },
	}
	for i = 1, #fields do
		local field = fields[i]
		req:append(field[1], field[2])
	end
	return req
end
---Perform a unary request.
---
---The request is encoded and sent as a single length-prefixed gRPC message.
---The complete response body is then read (it may arrive in several DATA
---chunks) before the first message in it is decoded. A non-zero `grpc-status`
---is reported as an error, whether it arrives in the initial headers
---(Trailers-Only response) or in the trailers. Encoding failures still raise,
---via `encode`.
---
---@nodiscard
---
---@param request_specifier grpc_client.RequestSpecifier
---@param data table The message to send. This should be in the structure of `request_specifier.request`.
---
---@return table|nil response The response as a table in the structure of `request_specifier.response`, or `nil` if there was an error.
---@return string|nil error An error string, if any.
function Client:unary_request(request_specifier, data)
	-- Returns an error string for a header block carrying a non-zero
	-- `grpc-status`, or nil when the status is absent or 0 (OK).
	local function status_error(hdrs)
		local raw_status = hdrs:get("grpc-status")
		if raw_status == nil then
			return nil
		end
		local code = tonumber(raw_status)
		if code == 0 then
			return nil
		end
		local err_name = code and require("grpc_client.status").name(code)
		local grpc_msg = hdrs:get("grpc-message")
		local msg_part = ""
		if grpc_msg then
			-- grpc-message is percent-encoded on the wire.
			local decoded = grpc_msg:gsub("%%(%x%x)", function(hex)
				return string.char(tonumber(hex, 16))
			end)
			msg_part = ", msg = " .. decoded
		end
		return "error from response: code = " .. (err_name or "unknown grpc status code") .. msg_part
	end

	local stream = self.conn:new_stream()
	local service = request_specifier.service
	local method = request_specifier.method
	local request_type = request_specifier.request
	local response_type = request_specifier.response
	local body = encode(request_type, data)
	stream:write_headers(create_request_headers(service, method), false)
	stream:write_chunk(body, true)

	local resp_headers, header_err = stream:get_headers()
	if not resp_headers then
		stream:shutdown()
		return nil, "failed to read response headers: " .. tostring(header_err)
	end
	local status_err = status_error(resp_headers)
	if status_err then
		stream:shutdown()
		return nil, status_err
	end

	-- Read until end of stream: one message may be split across DATA frames.
	-- get_next_chunk returns nil with no error at end of stream.
	local chunks = {}
	while true do
		local chunk, chunk_err = stream:get_next_chunk()
		if chunk == nil then
			if chunk_err ~= nil then
				stream:shutdown()
				return nil, "failed to read response body: " .. tostring(chunk_err)
			end
			break
		end
		chunks[#chunks + 1] = chunk
	end
	local response_body = table.concat(chunks)

	-- For a normal (non Trailers-Only) response the final status is in the trailers.
	local trailers = stream:get_headers()
	stream:shutdown()
	if trailers then
		status_err = status_error(trailers)
		if status_err then
			return nil, status_err
		end
	end

	-- Message frame: 1-byte compressed flag, 4-byte big-endian length, payload.
	if #response_body < 5 then
		return nil, "malformed response: missing gRPC message prefix"
	end
	if response_body:byte(1) ~= 0 then
		return nil, "malformed response: compressed messages are not supported"
	end
	local msg_len = string.unpack(">I4", response_body, 2)
	if #response_body < 5 + msg_len then
		return nil, "malformed response: truncated gRPC message"
	end

	local ok, response = pcall(pb.decode, response_type, response_body:sub(6, 5 + msg_len))
	if not ok or response == nil then
		return nil, "failed to decode `" .. tostring(response_type) .. "`: " .. tostring(response)
	end
	return response, nil
end
---Perform a server-streaming request.
---
---`callback` will be called with every streamed response.
---
---The request is sent and the initial response headers are checked before this
---function returns. The responses themselves are read by a coroutine added to
---`self.loop` with `wrap`, so `callback` only runs while that loop is being
---driven. Messages are reassembled across DATA chunk boundaries. The following
---are raised as errors from inside that coroutine: read errors, undecodable or
---truncated messages, and a non-zero `grpc-status` in the trailers.
---
---@nodiscard
---
---@param request_specifier grpc_client.RequestSpecifier
---@param data table The message to send. This should be in the structure of `request_specifier.request`.
---@param callback fun(response: table) A callback that will be run with every response
---
---@return string|nil error An error string, if any.
function Client:server_streaming_request(request_specifier, data, callback)
	-- Returns an error string for a header block carrying a non-zero
	-- `grpc-status`, or nil when the status is absent or 0 (OK).
	local function status_error(hdrs)
		local raw_status = hdrs:get("grpc-status")
		if raw_status == nil then
			return nil
		end
		local code = tonumber(raw_status)
		if code == 0 then
			return nil
		end
		local err_name = code and require("grpc_client.status").name(code)
		local grpc_msg = hdrs:get("grpc-message")
		local msg_part = ""
		if grpc_msg then
			-- grpc-message is percent-encoded on the wire.
			local decoded = grpc_msg:gsub("%%(%x%x)", function(hex)
				return string.char(tonumber(hex, 16))
			end)
			msg_part = ", msg = " .. decoded
		end
		return "error from response: code = " .. (err_name or "unknown grpc status code") .. msg_part
	end

	local stream = self.conn:new_stream()
	local service = request_specifier.service
	local method = request_specifier.method
	local request_type = request_specifier.request
	local response_type = request_specifier.response
	local body = encode(request_type, data)
	stream:write_headers(create_request_headers(service, method), false)
	stream:write_chunk(body, true)

	local resp_headers, header_err = stream:get_headers()
	if not resp_headers then
		stream:shutdown()
		return "failed to read response headers: " .. tostring(header_err)
	end
	local status_err = status_error(resp_headers)
	if status_err then
		stream:shutdown()
		return status_err
	end

	self.loop:wrap(function()
		-- Bytes received but not yet consumed. A message may straddle DATA
		-- chunks and a single chunk may carry several messages.
		local pending = ""
		for chunk in stream:each_chunk() do
			pending = pending .. chunk
			while #pending >= 5 do
				if pending:byte(1) ~= 0 then
					error("malformed response: compressed messages are not supported")
				end
				local msg_len = string.unpack(">I4", pending, 2)
				if #pending < 5 + msg_len then
					break -- wait for the rest of this message
				end
				local ok, response = pcall(pb.decode, response_type, pending:sub(6, 5 + msg_len))
				if not ok then
					error("failed to decode `" .. tostring(response_type) .. "`: " .. tostring(response))
				end
				pending = pending:sub(6 + msg_len)
				callback(response)
			end
		end
		if #pending > 0 then
			error("malformed response: stream ended with a truncated gRPC message")
		end

		-- The final status of the call is carried in the trailers.
		local trailers = stream:get_headers()
		stream:shutdown()
		if trailers then
			local trailer_err = status_error(trailers)
			if trailer_err then
				error(trailer_err)
			end
		end
	end)
	return nil
end
---Perform a bidirectional-streaming request.
---
---`callback` will be called with every streamed response.
---
---The raw client-to-server stream is returned to allow you to send encoded messages.
---
---The response headers are awaited inside a coroutine added to `self.loop`
---with `wrap`, not before returning. A server may delay its response headers
---until it sends its first response, which can depend on first receiving a
---request message. Waiting here would deadlock, because the caller cannot send
---anything until this function returns the stream.
---
---As a consequence, several failures are raised from that coroutine while the
---loop is driven:
---  * a non-zero `grpc-status` in the headers or trailers
---  * read errors
---  * undecodable or truncated messages
---This function itself always returns the stream. Messages are reassembled
---across DATA chunk boundaries.
---
---@nodiscard
---
---@param request_specifier grpc_client.RequestSpecifier
---@param callback fun(response: table, stream: grpc_client.h2.Stream) A callback that will be run with every response
---
---@return grpc_client.h2.Stream|nil
---@return string|nil error An error string, if any.
function Client:bidirectional_streaming_request(request_specifier, callback)
	-- Returns an error string for a header block carrying a non-zero
	-- `grpc-status`, or nil when the status is absent or 0 (OK).
	local function status_error(hdrs)
		local raw_status = hdrs:get("grpc-status")
		if raw_status == nil then
			return nil
		end
		local code = tonumber(raw_status)
		if code == 0 then
			return nil
		end
		local err_name = code and require("grpc_client.status").name(code)
		local grpc_msg = hdrs:get("grpc-message")
		local msg_part = ""
		if grpc_msg then
			-- grpc-message is percent-encoded on the wire.
			local decoded = grpc_msg:gsub("%%(%x%x)", function(hex)
				return string.char(tonumber(hex, 16))
			end)
			msg_part = ", msg = " .. decoded
		end
		return "error from response: code = " .. (err_name or "unknown grpc status code") .. msg_part
	end

	local stream = self.conn:new_stream()
	local service = request_specifier.service
	local method = request_specifier.method
	local response_type = request_specifier.response
	stream:write_headers(create_request_headers(service, method), false)

	self.loop:wrap(function()
		local resp_headers, header_err = stream:get_headers()
		if not resp_headers then
			error("failed to read response headers: " .. tostring(header_err))
		end
		local status_err = status_error(resp_headers)
		if status_err then
			error(status_err)
		end

		-- Bytes received but not yet consumed. A message may straddle DATA
		-- chunks and a single chunk may carry several messages.
		local pending = ""
		for chunk in stream:each_chunk() do
			pending = pending .. chunk
			while #pending >= 5 do
				if pending:byte(1) ~= 0 then
					error("malformed response: compressed messages are not supported")
				end
				local msg_len = string.unpack(">I4", pending, 2)
				if #pending < 5 + msg_len then
					break -- wait for the rest of this message
				end
				local ok, response = pcall(pb.decode, response_type, pending:sub(6, 5 + msg_len))
				if not ok then
					error("failed to decode `" .. tostring(response_type) .. "`: " .. tostring(response))
				end
				pending = pending:sub(6 + msg_len)
				callback(response, stream)
			end
		end
		if #pending > 0 then
			error("malformed response: stream ended with a truncated gRPC message")
		end

		-- The final status of the call is carried in the trailers.
		local trailers = stream:get_headers()
		stream:shutdown()
		if trailers then
			local trailer_err = status_error(trailers)
			if trailer_err then
				error(trailer_err)
			end
		end
	end)
	return stream, nil
end
return grpc_client
-- Definitions
--
-- LuaLS type annotations for the objects this module uses from cqueues and
-- lua-http (`grpc_client.h2.*`, `grpc_client.cqueues.*`), plus the request
-- specifier that callers pass to the request methods. They are only comments,
-- so placing them after the module's final `return` is valid Lua.
---@class grpc_client.h2.Conn
---@field new_stream fun(self: self): grpc_client.h2.Stream
---@field ping fun(self: self, timeout_secs: integer)
---@class grpc_client.cqueues.Loop
---@field loop function
---@field wrap fun(self: self, fn: function)
---@class grpc_client.h2.Stream
---@field write_chunk function
---@field shutdown function
---@field write_headers function
---@field get_headers function
---@field get_next_chunk function
---@field each_chunk function
---@class grpc_client.RequestSpecifier
---@field service string The fully-qualified service name
---@field method string The method name
---@field request string The fully-qualified request type
---@field response string The fully-qualified response type